From 5d6c920e6efdf6e96e9ba13817d0cbd3004164aa Mon Sep 17 00:00:00 2001 From: Ettore Tiotto Date: Thu, 1 Sep 2022 13:45:10 -0400 Subject: [PATCH] Conversion of 'sycl.constructor(%0, %1) {type = @range}' (#51) This PR adds support for converting the `sycl.constructor(%0, %1) {type = @range}` operation (representing construction of a sycl::range object) to a call to the appropriate `sycl::range` constructor. Signed-off-by: Tiotto, Ettore --- .../Conversion/SYCLToLLVM/SYCLFuncRegistry.h | 62 +++- .../SYCLToLLVM/SYCLFuncRegistry.cpp | 278 ++++++++++++------ .../lib/Conversion/SYCLToLLVM/SYCLToLLVM.cpp | 40 ++- .../SYCLToLLVM/func-ops-to-llvm.mlir | 156 +++++++++- 4 files changed, 408 insertions(+), 128 deletions(-) diff --git a/mlir-sycl/include/mlir/Conversion/SYCLToLLVM/SYCLFuncRegistry.h b/mlir-sycl/include/mlir/Conversion/SYCLToLLVM/SYCLFuncRegistry.h index e048566414a7f..d3545d0d14b80 100644 --- a/mlir-sycl/include/mlir/Conversion/SYCLToLLVM/SYCLFuncRegistry.h +++ b/mlir-sycl/include/mlir/Conversion/SYCLToLLVM/SYCLFuncRegistry.h @@ -30,6 +30,7 @@ class SYCLFuncRegistry; /// needs to be created in SYCLFuncRegistry constructor. class SYCLFuncDescriptor { friend class SYCLFuncRegistry; + friend llvm::raw_ostream &operator<<(llvm::raw_ostream &, const SYCLFuncDescriptor &); public: /// Enumerates SYCL functions. @@ -44,26 +45,51 @@ class SYCLFuncDescriptor { Id1CtorSizeT, // sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type) Id2CtorSizeT, // sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type) Id3CtorSizeT, // sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type) - Id1CtorRange, // sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long) - Id2CtorRange, // sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long) - Id3CtorRange, // sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long) - Id1CtorItem, // sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long, unsigned long) - Id2CtorItem, // sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long, unsigned long) - Id3CtorItem, // sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long, unsigned long) + Id1Ctor2SizeT, // sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long) + Id2Ctor2SizeT, // sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long) + Id3Ctor2SizeT, // sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long) + Id1Ctor3SizeT, // sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long, unsigned long) + Id2Ctor3SizeT, // sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long, unsigned long) + Id3Ctor3SizeT, // sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long, unsigned long) Id1CopyCtor, // sycl::id<1>::id(sycl::id<1> const&) Id2CopyCtor, // sycl::id<2>::id(sycl::id<2> const&) Id3CopyCtor, // sycl::id<3>::id(sycl::id<3> const&) - // Member functions for ..TODO.. + // Member functions for the sycl::Range class. + Range1CtorDefault, // sycl::Range<1>::range() + Range2CtorDefault, // sycl::range<2>::range() + Range3CtorDefault, // sycl::range<3>::range() + Range1CtorSizeT, // sycl::range<1>::range<1>(std::enable_if<(1)==(1), unsigned long>::type) + Range2CtorSizeT, // sycl::range<2>::range<2>(std::enable_if<(2)==(2), unsigned long>::type) + Range3CtorSizeT, // sycl::range<3>::range<3>(std::enable_if<(3)==(3), unsigned long>::type) + Range1Ctor2SizeT, // sycl::range<1>::range<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long) + Range2Ctor2SizeT, // sycl::range<2>::range<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long) + Range3Ctor2SizeT, // sycl::range<3>::range<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long) + Range1Ctor3SizeT, // sycl::range<1>::range<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long, unsigned long) + Range2Ctor3SizeT, // sycl::range<2>::range<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long, unsigned long) + Range3Ctor3SizeT, // sycl::range<3>::range<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long, unsigned long) + Range1CopyCtor, // sycl::range<1>::range(sycl::range<1> const&) + Range2CopyCtor, // sycl::range<2>::range(sycl::range<2> const&) + Range3CopyCtor, // sycl::range<3>::range(sycl::range<3> const&) }; // clang-format on /// Enumerates the kind of FuncId. enum class FuncIdKind { Unknown, - IdCtor, // any sycl::id constructors + IdCtor, // any sycl::id constructors. + RangeCtor // any sycl::range constructors. }; + /// Returns the funcIdKind given a \p funcId. + static FuncIdKind getFuncIdKind(FuncId funcId); + + /// Retuns a descriptive name for the given \p funcIdKind. + static std::string funcIdKindToName(FuncIdKind funcIdKind); + + /// Retuns the FuncIdKind given a descriptive \p name. + static FuncIdKind nameToFuncIdKind(Twine name); + // Call the SYCL constructor identified by \p id with the given \p args. static Value call(FuncId id, ValueRange args, const SYCLFuncRegistry ®istry, OpBuilder &b, @@ -73,8 +99,11 @@ class SYCLFuncDescriptor { /// Private constructor: only available to 'SYCLFuncRegistry'. SYCLFuncDescriptor(FuncId id, StringRef name, Type outputTy, ArrayRef argTys) - : id(id), name(name), outputTy(outputTy), - argTys(argTys.begin(), argTys.end()) {} + : funcId(id), funcIdKind(getFuncIdKind(id)), name(name), + outputTy(outputTy), argTys(argTys.begin(), argTys.end()) { + assert(funcId != FuncId::Unknown && "Illegal function id"); + assert(funcIdKind != FuncIdKind::Unknown && "Illegal function id kind"); + } /// Inject the declaration for this function into the module. void declareFunction(ModuleOp &module, OpBuilder &b); @@ -83,13 +112,22 @@ class SYCLFuncDescriptor { static bool isIdCtor(FuncId funcId); private: - FuncId id; // unique identifier for a SYCL function + FuncId funcId = FuncId::Unknown; // SYCL function identifier + FuncIdKind funcIdKind = FuncIdKind::Unknown; // SYCL function kind StringRef name; // SYCL function name Type outputTy; // SYCL function output type SmallVector argTys; // SYCL function arguments types - FlatSymbolRefAttr funcRef; // Reference to the SYCL function declaration + FlatSymbolRefAttr funcRef; // Reference to the SYCL function }; +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const SYCLFuncDescriptor &desc) { + os << "funcId=" << (int)desc.funcId + << ", funcIdKind=" << SYCLFuncDescriptor::funcIdKindToName(desc.funcIdKind) + << ", name='" << desc.name.str() << "')"; + return os; +} + /// \class SYCLFuncRegistry /// Singleton class representing the set of SYCL functions callable from the /// compiler. diff --git a/mlir-sycl/lib/Conversion/SYCLToLLVM/SYCLFuncRegistry.cpp b/mlir-sycl/lib/Conversion/SYCLToLLVM/SYCLFuncRegistry.cpp index 1339f24ac898e..879319414825d 100644 --- a/mlir-sycl/lib/Conversion/SYCLToLLVM/SYCLFuncRegistry.cpp +++ b/mlir-sycl/lib/Conversion/SYCLToLLVM/SYCLFuncRegistry.cpp @@ -27,33 +27,64 @@ using namespace mlir::sycl; // SYCLFuncDescriptor //===----------------------------------------------------------------------===// -void SYCLFuncDescriptor::declareFunction(ModuleOp &module, OpBuilder &b) { - LLVMBuilder builder(b, module.getLoc()); - funcRef = builder.getOrInsertFuncDecl(name, outputTy, argTys, module); -} - -bool SYCLFuncDescriptor::isIdCtor(FuncId funcId) { - switch (funcId) { +SYCLFuncDescriptor::FuncIdKind +SYCLFuncDescriptor::getFuncIdKind(FuncId funcId) { + switch(funcId) { case FuncId::Id1CtorDefault: case FuncId::Id2CtorDefault: case FuncId::Id3CtorDefault: case FuncId::Id1CtorSizeT: case FuncId::Id2CtorSizeT: case FuncId::Id3CtorSizeT: - case FuncId::Id1CtorRange: - case FuncId::Id2CtorRange: - case FuncId::Id3CtorRange: - case FuncId::Id1CtorItem: - case FuncId::Id2CtorItem: - case FuncId::Id3CtorItem: + case FuncId::Id1Ctor2SizeT: + case FuncId::Id2Ctor2SizeT: + case FuncId::Id3Ctor2SizeT: + case FuncId::Id1Ctor3SizeT: + case FuncId::Id2Ctor3SizeT: + case FuncId::Id3Ctor3SizeT: case FuncId::Id1CopyCtor: case FuncId::Id2CopyCtor: case FuncId::Id3CopyCtor: - return true; - default:; + return FuncIdKind::IdCtor; + case FuncId::Range1CtorDefault: + case FuncId::Range2CtorDefault: + case FuncId::Range3CtorDefault: + case FuncId::Range1CtorSizeT: + case FuncId::Range2CtorSizeT: + case FuncId::Range3CtorSizeT: + case FuncId::Range1Ctor2SizeT: + case FuncId::Range2Ctor2SizeT: + case FuncId::Range3Ctor2SizeT: + case FuncId::Range1Ctor3SizeT: + case FuncId::Range2Ctor3SizeT: + case FuncId::Range3Ctor3SizeT: + case FuncId::Range1CopyCtor: + case FuncId::Range2CopyCtor: + case FuncId::Range3CopyCtor: + return FuncIdKind::RangeCtor; + default: + return FuncIdKind::Unknown; + } +} + +std::string SYCLFuncDescriptor::funcIdKindToName(FuncIdKind funcIdKind) { + switch (funcIdKind) { + case FuncIdKind::IdCtor: + return "idCtor"; + case FuncIdKind::RangeCtor: + return "rangeCtor"; + default: + return "unknown"; } +} - return false; +SYCLFuncDescriptor::FuncIdKind +SYCLFuncDescriptor::nameToFuncIdKind(Twine name) { + if (name.str() == "idCtor") + return FuncIdKind::IdCtor; + if (name.str() == "rangeCtor") + return FuncIdKind::RangeCtor; + return FuncIdKind::Unknown; } Value SYCLFuncDescriptor::call(FuncId funcId, ValueRange args, @@ -77,6 +108,11 @@ Value SYCLFuncDescriptor::call(FuncId funcId, ValueRange args, return callOp.getResult(0); } +void SYCLFuncDescriptor::declareFunction(ModuleOp &module, OpBuilder &b) { + LLVMBuilder builder(b, module.getLoc()); + funcRef = builder.getOrInsertFuncDecl(name, outputTy, argTys, module); +} + //===----------------------------------------------------------------------===// // SYCLFuncRegistry //===----------------------------------------------------------------------===// @@ -96,40 +132,38 @@ SYCLFuncRegistry::getFuncId(SYCLFuncDescriptor::FuncIdKind funcIdKind, Type retType, TypeRange argTypes) const { assert(funcIdKind != SYCLFuncDescriptor::FuncIdKind::Unknown && "Invalid funcIdKind"); - - // Determines whether the given funcId has kind that matches the given - // funcIdKind. - auto kindMatches = [](SYCLFuncDescriptor::FuncId funcId, - SYCLFuncDescriptor::FuncIdKind funcIdKind) { - bool foundMatch = false; - switch (funcIdKind) { - case SYCLFuncDescriptor::FuncIdKind::IdCtor: - foundMatch = SYCLFuncDescriptor::isIdCtor(funcId); - break; - default: - foundMatch = false; - } - return foundMatch; - }; + LLVM_DEBUG(llvm::dbgs() << "Looking up function of kind: " + << SYCLFuncDescriptor::funcIdKindToName(funcIdKind) + << "\n";); for (const auto &entry : registry) { + const SYCLFuncDescriptor &desc = entry.second; + LLVM_DEBUG(llvm::dbgs() << desc << "\n"); + // Skip through entries that do not match the requested funcIdKind. - if (!kindMatches(entry.second.id, funcIdKind)) + if (desc.funcIdKind != funcIdKind) { + LLVM_DEBUG(llvm::dbgs() << "\tskip, kind does not match\n"); continue; - + } // Ensure that the entry has return and arguments type that match the one - // provided. - if (retType != entry.second.outputTy || - argTypes.size() != entry.second.argTys.size()) + // requested. + if (desc.outputTy != retType) { + LLVM_DEBUG(llvm::dbgs() << "\tskip, return type does not match\n"); continue; - if (!std::equal(argTypes.begin(), argTypes.end(), - entry.second.argTys.begin())) + } + if (desc.argTys.size() != argTypes.size()) { + LLVM_DEBUG(llvm::dbgs() << "\tskip, number of arguments does not match\n"); + continue; + } + if (!std::equal(argTypes.begin(), argTypes.end(), desc.argTys.begin())) { + LLVM_DEBUG(llvm::dbgs() << "\tskip, arguments types do not match\n"); continue; + } - return entry.second.id; + return desc.funcId; } - llvm_unreachable("Unimplemented descriptor"); + llvm_unreachable("Could not find function id"); return SYCLFuncDescriptor::FuncId::Unknown; } @@ -146,92 +180,152 @@ SYCLFuncRegistry::SYCLFuncRegistry(ModuleOp &module, OpBuilder &builder) converter.convertType(MemRefType::get(-1, IDType::get(context, 2))); Type id3PtrTy = converter.convertType(MemRefType::get(-1, IDType::get(context, 3))); + Type range1PtrTy = + converter.convertType(MemRefType::get(-1, RangeType::get(context, 1))); + Type range2PtrTy = + converter.convertType(MemRefType::get(-1, RangeType::get(context, 2))); + Type range3PtrTy = + converter.convertType(MemRefType::get(-1, RangeType::get(context, 3))); + auto voidTy = LLVM::LLVMVoidType::get(context); auto i64Ty = IntegerType::get(context, 64); - // Construct the SYCL functions descriptors (enum, - // function name, signature). + // Construct the SYCL functions descriptors for the sycl::id type. + // Descriptor format: (enum, function name, signature). // clang-format off - std::vector descriptors = { - // cl::sycl::id<1>::id() + std::vector idDescriptors = { + // sycl::id<1>::id() SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Id1CtorDefault, - "_ZN2cl4sycl2idILi1EEC2Ev", voidTy, {id1PtrTy}), - // cl::sycl::id<2>::id() + "_ZN2cl4sycl2idILi1EEC2Ev", voidTy, {id1PtrTy}), + // sycl::id<2>::id() SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Id2CtorDefault, - "_ZN2cl4sycl2idILi2EEC2Ev", voidTy, {id2PtrTy}), - // cl::sycl::id<3>::id() + "_ZN2cl4sycl2idILi2EEC2Ev", voidTy, {id2PtrTy}), + // sycl::id<3>::id() SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Id3CtorDefault, - "_ZN2cl4sycl2idILi3EEC2Ev", voidTy, {id3PtrTy}), + "_ZN2cl4sycl2idILi3EEC2Ev", voidTy, {id3PtrTy}), - // cl::sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type) - SYCLFuncDescriptor( - SYCLFuncDescriptor::FuncId::Id1CtorSizeT, + // sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type) + SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Id1CtorSizeT, "_ZN2cl4sycl2idILi1EEC2ILi1EEENSt9enable_ifIXeqT_Li1EEmE4typeE", voidTy, {id1PtrTy, i64Ty}), - // cl::sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type) - SYCLFuncDescriptor( - SYCLFuncDescriptor::FuncId::Id2CtorSizeT, + // sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type) + SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Id2CtorSizeT, "_ZN2cl4sycl2idILi2EEC2ILi2EEENSt9enable_ifIXeqT_Li2EEmE4typeE", voidTy, {id2PtrTy, i64Ty}), - // cl::sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type) - SYCLFuncDescriptor( - SYCLFuncDescriptor::FuncId::Id3CtorSizeT, + // sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type) + SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Id3CtorSizeT, "_ZN2cl4sycl2idILi3EEC2ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeE", voidTy, {id3PtrTy, i64Ty}), - // cl::sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long) - SYCLFuncDescriptor( - SYCLFuncDescriptor::FuncId::Id1CtorRange, + // sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long) + SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Id1Ctor2SizeT, "_ZN2cl4sycl2idILi1EEC2ILi1EEENSt9enable_ifIXeqT_Li1EEmE4typeEm", voidTy, {id1PtrTy, i64Ty, i64Ty}), - // cl::sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long) - SYCLFuncDescriptor( - SYCLFuncDescriptor::FuncId::Id2CtorRange, + // sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long) + SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Id2Ctor2SizeT, "_ZN2cl4sycl2idILi2EEC2ILi2EEENSt9enable_ifIXeqT_Li2EEmE4typeEm", voidTy, {id2PtrTy, i64Ty, i64Ty}), - // cl::sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long) - SYCLFuncDescriptor( - SYCLFuncDescriptor::FuncId::Id3CtorRange, + // sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long) + SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Id3Ctor2SizeT, "_ZN2cl4sycl2idILi3EEC2ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeEm", voidTy, {id3PtrTy, i64Ty, i64Ty}), - // cl::sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long, unsigned long) - SYCLFuncDescriptor( - SYCLFuncDescriptor::FuncId::Id1CtorItem, + // sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long, unsigned long) + SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Id1Ctor3SizeT, "_ZN2cl4sycl2idILi1EEC2ILi1EEENSt9enable_ifIXeqT_Li1EEmE4typeEmm", voidTy, {id1PtrTy, i64Ty, i64Ty, i64Ty}), - // cl::sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long, unsigned long) - SYCLFuncDescriptor( - SYCLFuncDescriptor::FuncId::Id2CtorItem, + // sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long, unsigned long) + SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Id2Ctor3SizeT, "_ZN2cl4sycl2idILi2EEC2ILi2EEENSt9enable_ifIXeqT_Li2EEmE4typeEmm", voidTy, {id2PtrTy, i64Ty, i64Ty, i64Ty}), - // cl::sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long, unsigned long) - SYCLFuncDescriptor( - SYCLFuncDescriptor::FuncId::Id3CtorItem, + // sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long, unsigned long) + SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Id3Ctor3SizeT, "_ZN2cl4sycl2idILi3EEC2ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeEmm", voidTy, {id3PtrTy, i64Ty, i64Ty, i64Ty}), - // cl::sycl::id<1>::id(cl::sycl::id<1> const&) - SYCLFuncDescriptor( - SYCLFuncDescriptor::FuncId::Id1CopyCtor, - "_ZN2cl4sycl2idILi1EEC1ERKS2_", - voidTy, {id1PtrTy, id1PtrTy}), - // cl::sycl::id<2>::id(cl::sycl::id<2> const&) - SYCLFuncDescriptor( - SYCLFuncDescriptor::FuncId::Id2CopyCtor, - "_ZN2cl4sycl2idILi2EEC1ERKS2_", - voidTy, {id2PtrTy, id2PtrTy}), - // cl::sycl::id<3>::id(cl::sycl::id<3> const&) - SYCLFuncDescriptor( - SYCLFuncDescriptor::FuncId::Id3CopyCtor, - "_ZN2cl4sycl2idILi3EEC1ERKS2_", - voidTy, {id3PtrTy, id3PtrTy}), + // sycl::id<1>::id(sycl::id<1> const&) + SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Id1CopyCtor, + "_ZN2cl4sycl2idILi1EEC1ERKS2_", voidTy, {id1PtrTy, id1PtrTy}), + // sycl::id<2>::id(sycl::id<2> const&) + SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Id2CopyCtor, + "_ZN2cl4sycl2idILi2EEC1ERKS2_", voidTy, {id2PtrTy, id2PtrTy}), + // sycl::id<3>::id(sycl::id<3> const&) + SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Id3CopyCtor, + "_ZN2cl4sycl2idILi3EEC1ERKS2_", voidTy, {id3PtrTy, id3PtrTy}), + }; + + // Construct the SYCL functions descriptors for the sycl::range type. + std::vector rangeDescriptors = { + // sycl::range<1>::range() + SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Range1CtorDefault, + "_ZN2cl4sycl5rangeILi1EEC2Ev", voidTy, {range1PtrTy}), + // sycl::range<2>::range() + SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Range2CtorDefault, + "_ZN2cl4sycl5rangeILi2EEC2Ev", voidTy, {range2PtrTy}), + // sycl::range<3>::range() + SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Range3CtorDefault, + "_ZN2cl4sycl5rangeILi3EEC2Ev", voidTy, {range3PtrTy}), + + // sycl::range<1>::range<1>(std::enable_if<(1)==(1), unsigned long>::type) + SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Range1CtorSizeT, + "_ZN2cl4sycl5rangeILi1EEC2ILi1EEENSt9enable_ifIXeqT_Li1EEmE4typeE", + voidTy, {range1PtrTy, i64Ty}), + // sycl::range<2>::range<2>(std::enable_if<(2)==(2), unsigned long>::type) + SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Range2CtorSizeT, + "_ZN2cl4sycl5rangeILi2EEC2ILi2EEENSt9enable_ifIXeqT_Li2EEmE4typeE", + voidTy, {range2PtrTy, i64Ty}), + // sycl::range<3>::range<3>(std::enable_if<(3)==(3), unsigned long>::type) + SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Range3CtorSizeT, + "_ZN2cl4sycl5rangeILi3EEC2ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeE", + voidTy, {range3PtrTy, i64Ty}), + + // sycl::range<1>::range<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long) + SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Range1Ctor2SizeT, + "_ZN2cl4sycl5rangeILi1EEC2ILi1EEENSt9enable_ifIXeqT_Li1EEmE4typeEm", + voidTy, {range1PtrTy, i64Ty, i64Ty}), + // sycl::range<2>::range<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long) + SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Range2Ctor2SizeT, + "_ZN2cl4sycl5rangeILi2EEC2ILi2EEENSt9enable_ifIXeqT_Li2EEmE4typeEm", + voidTy, {range2PtrTy, i64Ty, i64Ty}), + // sycl::range<3>::range<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long) + SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Range3Ctor2SizeT, + "_ZN2cl4sycl5rangeILi3EEC2ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeEm", + voidTy, {range3PtrTy, i64Ty, i64Ty}), + + // sycl::range<1>::range<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long, unsigned long) + SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Range1Ctor3SizeT, + "_ZN2cl4sycl5rangeILi1EEC2ILi1EEENSt9enable_ifIXeqT_Li1EEmE4typeEmm", + voidTy, {range1PtrTy, i64Ty, i64Ty, i64Ty}), + // sycl::range<2>::range<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long, unsigned long) + SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Range2Ctor3SizeT, + "_ZN2cl4sycl5rangeILi2EEC2ILi2EEENSt9enable_ifIXeqT_Li2EEmE4typeEmm", + voidTy, {range2PtrTy, i64Ty, i64Ty, i64Ty}), + // sycl::range<3>::range<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long, unsigned long) + SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Range3Ctor3SizeT, + "_ZN2cl4sycl5rangeILi3EEC2ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeEmm", + voidTy, {range3PtrTy, i64Ty, i64Ty, i64Ty}), + + // sycl::range<1>::range(sycl::range<1> const&) + SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Range1CopyCtor, + "_ZN2cl4sycl5rangeILi1EEC1ERKS2_", voidTy, {range1PtrTy, range1PtrTy}), + // sycl::range<2>::range(sycl::range<2> const&) + SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Range2CopyCtor, + "_ZN2cl4sycl5rangeILi2EEC1ERKS2_", voidTy, {range2PtrTy, range2PtrTy}), + // sycl::range<3>::range(sycl::range<3> const&) + SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Range3CopyCtor, + "_ZN2cl4sycl5rangeILi3EEC1ERKS2_", voidTy, {range3PtrTy, range3PtrTy}), }; // clang-format on + // Concatenate all descriptors. + std::vector descriptors; + descriptors.reserve(idDescriptors.size() + rangeDescriptors.size()); + descriptors.insert(descriptors.end(), idDescriptors.begin(), idDescriptors.end()); + descriptors.insert(descriptors.end(), rangeDescriptors.begin(), rangeDescriptors.end()); + // Declare SYCL functions and add them to the registry. for (SYCLFuncDescriptor &funcDesc : descriptors) { funcDesc.declareFunction(module, builder); - registry.emplace(funcDesc.id, funcDesc); + registry.emplace(funcDesc.funcId, funcDesc); } } diff --git a/mlir-sycl/lib/Conversion/SYCLToLLVM/SYCLToLLVM.cpp b/mlir-sycl/lib/Conversion/SYCLToLLVM/SYCLToLLVM.cpp index e00e882d23caf..31dc7dc9b3297 100644 --- a/mlir-sycl/lib/Conversion/SYCLToLLVM/SYCLToLLVM.cpp +++ b/mlir-sycl/lib/Conversion/SYCLToLLVM/SYCLToLLVM.cpp @@ -202,36 +202,32 @@ class ConstructorPattern final LogicalResult matchAndRewrite(sycl::SYCLConstructorOp op, OpAdaptor opAdaptor, ConversionPatternRewriter &rewriter) const override { - StringRef typeStr = op.Type(); - if (typeStr == "id") - return rewriteIdConstructor(op, opAdaptor, rewriter); - - LLVM_DEBUG(llvm::dbgs() << "op: "; op.dump(); llvm::dbgs() << "\n"); - llvm_unreachable("Unhandled sycl.constructor type"); - - return failure(); + Twine name = op.Type() + "Ctor"; + return rewriteConstructor(SYCLFuncDescriptor::nameToFuncIdKind(name), op, + opAdaptor, rewriter); } - /// Rewrite sycl.constructor() { type = @id } to a LLVM call to the - /// appropriate constructor function for sycl::id. - LogicalResult - rewriteIdConstructor(SYCLConstructorOp op, OpAdaptor opAdaptor, - ConversionPatternRewriter &rewriter) const { - assert(op.Type() == "id" && "Unexpected sycl.constructor type"); +private: + /// Rewrite sycl.constructor() { type = * } to a LLVM call to the appropriate + /// constructor function. + LogicalResult rewriteConstructor(SYCLFuncDescriptor::FuncIdKind ctorKind, + SYCLConstructorOp op, OpAdaptor opAdaptor, + ConversionPatternRewriter &rewriter) const { + assert((ctorKind != SYCLFuncDescriptor::FuncIdKind::Unknown) && + "Unexpected ctorKind"); LLVM_DEBUG(llvm::dbgs() << "ConstructorPattern: Rewriting op: "; op.dump(); llvm::dbgs() << "\n"); ModuleOp module = op.getOperation()->getParentOfType(); - MLIRContext *context = module.getContext(); - - // Lookup the ctor function to use. const auto ®istry = SYCLFuncRegistry::create(module, rewriter); - auto voidTy = LLVM::LLVMVoidType::get(context); - SYCLFuncDescriptor::FuncId funcId = - registry.getFuncId(SYCLFuncDescriptor::FuncIdKind::IdCtor, voidTy, - opAdaptor.Args().getTypes()); - // Generate an LLVM call to the appropriate ctor. + /// Lookup the FuncId corresponding to the ctor function to use, which is + /// determined based on 'ctorKind) the kind of constructor to search for, and + /// the LLVM types of the sycl.constructor arguments. + SYCLFuncDescriptor::FuncId funcId = registry.getFuncId( + ctorKind, LLVM::LLVMVoidType::get(module.getContext()), + opAdaptor.Args().getTypes()); + SYCLFuncDescriptor::call(funcId, opAdaptor.getOperands(), registry, rewriter, op.getLoc()); diff --git a/mlir-sycl/test/Conversion/SYCLToLLVM/func-ops-to-llvm.mlir b/mlir-sycl/test/Conversion/SYCLToLLVM/func-ops-to-llvm.mlir index 2ddb1386474ba..5c848b4fdd11f 100644 --- a/mlir-sycl/test/Conversion/SYCLToLLVM/func-ops-to-llvm.mlir +++ b/mlir-sycl/test/Conversion/SYCLToLLVM/func-ops-to-llvm.mlir @@ -42,7 +42,6 @@ func.func @id1CtorSizeT(%arg0: memref>, %arg1: i64) { return } - // ----- // CHECK: llvm.func @_ZN2cl4sycl2idILi2EEC2ILi2EEENSt9enable_ifIXeqT_Li2EEmE4typeE([[THIS_PTR_TYPE:!llvm.struct<\(ptr>, %arg1: i64, %arg2: i64, %ar // ----- //===-------------------------------------------------------------------------------------------------===// -// Constructors sycl::id::id(sycl::id const&) +// Constructors sycl::id::id(sycl::id const&, sycl::id const&) //===-------------------------------------------------------------------------------------------------===// // CHECK: llvm.func @_ZN2cl4sycl2idILi1EEC1ERKS2_([[THIS_PTR_TYPE:!llvm.struct<\(ptr>, %arg1: memref } // ----- + +//===-------------------------------------------------------------------------------------------------===// +// Constructors for sycl::range::range() +//===-------------------------------------------------------------------------------------------------===// + +// CHECK: llvm.func @_ZN2cl4sycl5rangeILi1EEC2Ev([[THIS_PTR_TYPE:!llvm.struct<\(ptr>) { + // CHECK: llvm.call @_ZN2cl4sycl5rangeILi1EEC2Ev({{.*}}) : ([[THIS_PTR_TYPE]]) -> () + sycl.constructor(%arg0) {Type = @range} : (memref>) -> () + return +} + +// ----- + +// CHECK: llvm.func @_ZN2cl4sycl5rangeILi2EEC2Ev([[THIS_PTR_TYPE:!llvm.struct<\(ptr>) { + // CHECK: llvm.call @_ZN2cl4sycl5rangeILi2EEC2Ev({{.*}}) : ([[THIS_PTR_TYPE]]) -> () + sycl.constructor(%arg0) {Type = @range} : (memref>) -> () + return +} + +// ----- + +// CHECK: llvm.func @_ZN2cl4sycl5rangeILi3EEC2Ev([[THIS_PTR_TYPE:!llvm.struct<\(ptr>) { + // CHECK: llvm.call @_ZN2cl4sycl5rangeILi3EEC2Ev({{.*}}) : ([[THIS_PTR_TYPE]]) -> () + sycl.constructor(%arg0) {Type = @range} : (memref>) -> () + return +} + +// ----- + +//===-------------------------------------------------------------------------------------------------===// +// Constructors for cl::sycl::range::range(std::enable_if<(n)==(n), unsigned long>::type) +//===-------------------------------------------------------------------------------------------------===// + +// CHECK: llvm.func @_ZN2cl4sycl5rangeILi1EEC2ILi1EEENSt9enable_ifIXeqT_Li1EEmE4typeE([[THIS_PTR_TYPE:!llvm.struct<\(ptr>, %arg1: i64) { + // CHECK: llvm.call @_ZN2cl4sycl5rangeILi1EEC2ILi1EEENSt9enable_ifIXeqT_Li1EEmE4typeE({{.*}}, %arg5) : ([[THIS_PTR_TYPE]], i64) -> () + sycl.constructor(%arg0, %arg1) {Type = @range} : (memref>, i64) -> () + return +} + +// ----- + +// CHECK: llvm.func @_ZN2cl4sycl5rangeILi2EEC2ILi2EEENSt9enable_ifIXeqT_Li2EEmE4typeE([[THIS_PTR_TYPE:!llvm.struct<\(ptr>, %arg1: i64) { + // CHECK: llvm.call @_ZN2cl4sycl5rangeILi2EEC2ILi2EEENSt9enable_ifIXeqT_Li2EEmE4typeE({{.*}}, %arg5) : ([[THIS_PTR_TYPE]], i64) -> () + sycl.constructor(%arg0, %arg1) {Type = @range} : (memref>, i64) -> () + return +} + +// ----- + +// CHECK: llvm.func @_ZN2cl4sycl5rangeILi3EEC2ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeE([[THIS_PTR_TYPE:!llvm.struct<\(ptr>, %arg1: i64) { + // CHECK: llvm.call @_ZN2cl4sycl5rangeILi3EEC2ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeE({{.*}}, %arg5) : ([[THIS_PTR_TYPE]], i64) -> () + sycl.constructor(%arg0, %arg1) {Type = @range} : (memref>, i64) -> () + return +} + +// ----- + +//===-------------------------------------------------------------------------------------------------===// +// Constructors for cl::sycl::range::range(std::enable_if<(n)==(n), unsigned long>::type, unsigned long) +//===-------------------------------------------------------------------------------------------------===// + +// CHECK: llvm.func @_ZN2cl4sycl5rangeILi1EEC2ILi1EEENSt9enable_ifIXeqT_Li1EEmE4typeEm([[THIS_PTR_TYPE:!llvm.struct<\(ptr>, %arg1: i64, %arg2: i64) { + // CHECK: llvm.call @_ZN2cl4sycl5rangeILi1EEC2ILi1EEENSt9enable_ifIXeqT_Li1EEmE4typeEm({{.*}}, %arg5, %arg6) : ([[THIS_PTR_TYPE]], i64, i64) -> () + sycl.constructor(%arg0, %arg1, %arg2) {Type = @range} : (memref>, i64, i64) -> () + return +} + +// ----- + +// CHECK: llvm.func @_ZN2cl4sycl5rangeILi2EEC2ILi2EEENSt9enable_ifIXeqT_Li2EEmE4typeEm([[THIS_PTR_TYPE:!llvm.struct<\(ptr>, %arg1: i64, %arg2: i64) { + // CHECK: llvm.call @_ZN2cl4sycl5rangeILi2EEC2ILi2EEENSt9enable_ifIXeqT_Li2EEmE4typeEm({{.*}}, %arg5, %arg6) : ([[THIS_PTR_TYPE]], i64, i64) -> () + sycl.constructor(%arg0, %arg1, %arg2) {Type = @range} : (memref>, i64, i64) -> () + return +} + +// ----- + +// CHECK: llvm.func @_ZN2cl4sycl5rangeILi3EEC2ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeEm([[THIS_PTR_TYPE:!llvm.struct<\(ptr>, %arg1: i64, %arg2: i64) { + // CHECK: llvm.call @_ZN2cl4sycl5rangeILi3EEC2ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeEm({{.*}}, %arg5, %arg6) : ([[THIS_PTR_TYPE]], i64, i64) -> () + sycl.constructor(%arg0, %arg1, %arg2) {Type = @range} : (memref>, i64, i64) -> () + return +} + +// ----- + +//===-------------------------------------------------------------------------------------------------===// +// Constructors for cl::sycl::range::range(std::enable_if<(n)==(n), unsigned long>::type, unsigned long, unsigned long) +//===-------------------------------------------------------------------------------------------------===// + +// CHECK: llvm.func @_ZN2cl4sycl5rangeILi1EEC2ILi1EEENSt9enable_ifIXeqT_Li1EEmE4typeEmm([[THIS_PTR_TYPE:!llvm.struct<\(ptr>, %arg1: i64, %arg2: i64, %arg3: i64) { + // CHECK: llvm.call @_ZN2cl4sycl5rangeILi1EEC2ILi1EEENSt9enable_ifIXeqT_Li1EEmE4typeEmm({{.*}}, %arg5, %arg6, %arg7) : ([[THIS_PTR_TYPE]], i64, i64, i64) -> () + sycl.constructor(%arg0, %arg1, %arg2, %arg3) {Type = @range} : (memref>, i64, i64, i64) -> () + return +} + +// ----- + +// CHECK: llvm.func @_ZN2cl4sycl5rangeILi2EEC2ILi2EEENSt9enable_ifIXeqT_Li2EEmE4typeEmm([[THIS_PTR_TYPE:!llvm.struct<\(ptr>, %arg1: i64, %arg2: i64, %arg3: i64) { + // CHECK: llvm.call @_ZN2cl4sycl5rangeILi2EEC2ILi2EEENSt9enable_ifIXeqT_Li2EEmE4typeEmm({{.*}}, %arg5, %arg6, %arg7) : ([[THIS_PTR_TYPE]], i64, i64, i64) -> () + sycl.constructor(%arg0, %arg1, %arg2, %arg3) {Type = @range} : (memref>, i64, i64, i64) -> () + return +} + +// ----- + +// CHECK: llvm.func @_ZN2cl4sycl5rangeILi3EEC2ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeEmm([[THIS_PTR_TYPE:!llvm.struct<\(ptr>, %arg1: i64, %arg2: i64, %arg3: i64) { + // CHECK: llvm.call @_ZN2cl4sycl5rangeILi3EEC2ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeEmm({{.*}}, %arg5, %arg6, %arg7) : ([[THIS_PTR_TYPE]], i64, i64, i64) -> () + sycl.constructor(%arg0, %arg1, %arg2, %arg3) {Type = @range} : (memref>, i64, i64, i64) -> () + return +} + +// ----- + +//===-------------------------------------------------------------------------------------------------===// +// Constructors sycl::range::id(sycl::range const&, sycl::range const&) +//===-------------------------------------------------------------------------------------------------===// + +// CHECK: llvm.func @_ZN2cl4sycl5rangeILi1EEC1ERKS2_([[THIS_PTR_TYPE:!llvm.struct<\(ptr>, %arg1: memref>) { + // CHECK: llvm.call @_ZN2cl4sycl5rangeILi1EEC1ERKS2_({{.*}}, {{.*}}) : ([[THIS_PTR_TYPE]], [[THIS_PTR_TYPE]]) -> () + "sycl.constructor"(%arg0, %arg1) {Type = @range} : (memref>, memref>) -> () + return +} + +// ----- + +// CHECK: llvm.func @_ZN2cl4sycl5rangeILi2EEC1ERKS2_([[THIS_PTR_TYPE:!llvm.struct<\(ptr>, %arg1: memref>) { + // CHECK: llvm.call @_ZN2cl4sycl5rangeILi2EEC1ERKS2_({{.*}}, {{.*}}) : ([[THIS_PTR_TYPE]], [[THIS_PTR_TYPE]]) -> () + "sycl.constructor"(%arg0, %arg1) {Type = @range} : (memref>, memref>) -> () + return +} + +// ----- + +// CHECK: llvm.func @_ZN2cl4sycl5rangeILi3EEC1ERKS2_([[THIS_PTR_TYPE:!llvm.struct<\(ptr>, %arg1: memref>) { + // CHECK: llvm.call @_ZN2cl4sycl5rangeILi3EEC1ERKS2_({{.*}}, {{.*}}) : ([[THIS_PTR_TYPE]], [[THIS_PTR_TYPE]]) -> () + "sycl.constructor"(%arg0, %arg1) {Type = @range} : (memref>, memref>) -> () + return +}