Skip to content

Commit

Permalink
Add type converters from sycl::AccessorImplDeviceType and sycl::Acces…
Browse files Browse the repository at this point in the history
…sorType to llvm types (#38)

The runtime class of `sycl::AccessorImplDeviceType`:
```
template <int Dims> class AccessorImplDevice {
...
  id<Dims> Offset;
  range<Dims> AccessRange;
  range<Dims> MemRange;
...
};
```
The runtime class of `sycl::AccessorType`:
```
template <typename DataT, int Dimensions, access::mode AccessMode,
          access::target AccessTarget, access::placeholder IsPlaceholder,
          typename PropertyListT>
class __SYCL_SPECIAL_CLASS accessor :
    public detail::accessor_common<DataT, Dimensions, AccessMode, AccessTarget,
                                   IsPlaceholder, PropertyListT> {
...
  detail::AccessorImplDevice<AdjustedDim> impl;
  union {
    ConcreteASPtrType MData;
  };
...
};
```

Example of LLVM IR generated directly from clang:
```
%"class.cl::sycl::accessor" = type { %"class.cl::sycl::detail::AccessorImplDevice", %union.anon }
%"class.cl::sycl::detail::AccessorImplDevice" = type { %"class.cl::sycl::id", %"class.cl::sycl::range", %"class.cl::sycl::range" }
%union.anon = type { i32 addrspace(1)* }
```Signed-off-by: Tsang, Whitney <[email protected]>
  • Loading branch information
whitneywhtsang authored and etiotto committed Sep 6, 2022
1 parent 502d3b9 commit 6f4d2dc
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 34 deletions.
128 changes: 100 additions & 28 deletions mlir-sycl/lib/Conversion/SYCLToLLVM/SYCLToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,39 +30,88 @@ using namespace mlir;
// Utility functions
//===----------------------------------------------------------------------===//

// Get the LLVM type of "class.cl::sycl::detail::array" with number of
// dimentions \p dimNum and element type \p type.
static Type getSYCLArrayTy(MLIRContext &context, unsigned dimNum, Type type) {
// Get LLVM type of "class.cl::sycl::detail::array" with \p dimNum number of
// dimensions and element type \p type.
static Optional<Type> getArrayTy(MLIRContext &context, unsigned dimNum,
Type type) {
assert((dimNum == 1 || dimNum == 2 || dimNum == 3) &&
"Expecting number of dimensions to be 1, 2, or 3.");
auto structTy = LLVM::LLVMStructType::getIdentified(
&context, "class.cl::sycl::detail::array." + std::to_string(dimNum));
if (!structTy.isInitialized()) {
auto arrayTy = LLVM::LLVMArrayType::get(type, dimNum);
auto res = structTy.setBody({arrayTy}, /*isPacked=*/false);
assert(succeeded(res) &&
"Unexpected failure from LLVMStructType::setBody.");
if (failed(structTy.setBody(arrayTy, /*isPacked=*/false)))
return llvm::None;
}
return structTy;
}

// Get the LLVM type of a SYCL range or id type, given \p type - the type in
// SYCL, \p name - the expected LLVM type name, \p converter - LLVM type
// converter.
template <typename T>
static Type getSYCLRangeOrIDTy(T type, StringRef name,
LLVMTypeConverter &converter) {
unsigned dimNum = type.getDimension();
auto structTy = LLVM::LLVMStructType::getIdentified(
//===----------------------------------------------------------------------===//
// Type conversion
//===----------------------------------------------------------------------===//

/// Converts SYCL range or id type to LLVM type, given \p dimNum - number of
/// dimensions, \p name - the expected LLVM type name, \p converter - LLVM type
/// converter.
static Optional<Type> convertRangeOrIDTy(unsigned dimNum, StringRef name,
LLVMTypeConverter &converter) {
auto convertedTy = LLVM::LLVMStructType::getIdentified(
&converter.getContext(), name.str() + "." + std::to_string(dimNum));
if (!structTy.isInitialized()) {
auto res = structTy.setBody(getSYCLArrayTy(converter.getContext(), dimNum,
converter.getIndexType()),
/*isPacked=*/false);
assert(succeeded(res) &&
"Unexpected failure from LLVMStructType::setBody.");
if (!convertedTy.isInitialized()) {
auto arrayTy =
getArrayTy(converter.getContext(), dimNum, converter.getIndexType());
if (!arrayTy.hasValue())
return llvm::None;
if (failed(convertedTy.setBody(arrayTy.getValue(), /*isPacked=*/false)))
return llvm::None;
}
return LLVM::LLVMPointerType::get(structTy);
return convertedTy;
}

/// Converts SYCL id type to LLVM type.
static Optional<Type> convertIDType(sycl::IDType type,
LLVMTypeConverter &converter) {
return convertRangeOrIDTy(type.getDimension(), "class.cl::sycl::id",
converter);
}

/// Converts SYCL range type to LLVM type.
static Optional<Type> convertRangeType(sycl::RangeType type,
LLVMTypeConverter &converter) {
return convertRangeOrIDTy(type.getDimension(), "class.cl::sycl::range",
converter);
}

/// Converts SYCL accessor implement device type to LLVM type.
static Optional<Type>
convertAccessorImplDeviceType(sycl::AccessorImplDeviceType type,
LLVMTypeConverter &converter) {
SmallVector<Type> convertedElemTypes;
convertedElemTypes.reserve(type.getBody().size());
if (failed(converter.convertTypes(type.getBody(), convertedElemTypes)))
return llvm::None;

return LLVM::LLVMStructType::getNewIdentified(
&converter.getContext(), "class.cl::sycl::detail::AccessorImplDevice",
convertedElemTypes, /*isPacked=*/false);
}

/// Converts SYCL accessor type to LLVM type.
static Optional<Type> convertAccessorType(sycl::AccessorType type,
LLVMTypeConverter &converter) {
SmallVector<Type> convertedElemTypes;
convertedElemTypes.reserve(type.getBody().size());
if (failed(converter.convertTypes(type.getBody(), convertedElemTypes)))
return llvm::None;

auto ptrTy = LLVM::LLVMPointerType::get(type.getType(), /*addressSpace=*/1);
auto structTy =
LLVM::LLVMStructType::getLiteral(&converter.getContext(), ptrTy);
convertedElemTypes.push_back(structTy);

return LLVM::LLVMStructType::getNewIdentified(
&converter.getContext(), "class.cl::sycl::accessor", convertedElemTypes,
/*isPacked=*/false);
}

//===----------------------------------------------------------------------===//
Expand All @@ -71,17 +120,40 @@ static Type getSYCLRangeOrIDTy(T type, StringRef name,

void mlir::sycl::populateSYCLToLLVMTypeConversion(
LLVMTypeConverter &typeConverter) {
typeConverter.addConversion([&](mlir::sycl::IDType type) {
return getSYCLRangeOrIDTy<mlir::sycl::IDType>(type, "class.cl::sycl::id",
typeConverter);
typeConverter.addConversion([&](sycl::AccessorImplDeviceType type) {
return convertAccessorImplDeviceType(type, typeConverter);
});
typeConverter.addConversion([&](sycl::AccessorType type) {
return convertAccessorType(type, typeConverter);
});
typeConverter.addConversion([&](sycl::ArrayType type) {
llvm_unreachable("SYCLToLLVM - sycl::ArrayType not handle (yet)");
return llvm::None;
});
typeConverter.addConversion([&](sycl::GroupType type) {
llvm_unreachable("SYCLToLLVM - sycl::GroupType not handle (yet)");
return llvm::None;
});
typeConverter.addConversion(
[&](sycl::IDType type) { return convertIDType(type, typeConverter); });
typeConverter.addConversion([&](sycl::ItemBaseType type) {
llvm_unreachable("SYCLToLLVM - sycl::ItemBaseType not handle (yet)");
return llvm::None;
});
typeConverter.addConversion([&](sycl::ItemType type) {
llvm_unreachable("SYCLToLLVM - sycl::ItemType not handle (yet)");
return llvm::None;
});
typeConverter.addConversion([&](sycl::NdItemType type) {
llvm_unreachable("SYCLToLLVM - sycl::NdItemType not handle (yet)");
return llvm::None;
});
typeConverter.addConversion([&](mlir::sycl::RangeType type) {
return getSYCLRangeOrIDTy<mlir::sycl::RangeType>(
type, "class.cl::sycl::range", typeConverter);
typeConverter.addConversion([&](sycl::RangeType type) {
return convertRangeType(type, typeConverter);
});
}

void mlir::sycl::populateSYCLToLLVMConversionPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
populateSYCLToLLVMTypeConversion(typeConverter);
populateSYCLToLLVMTypeConversion(typeConverter);
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,20 @@

// CHECK-DAG: [[ARRAY_1:.*]] = type { [1 x i64] }
// CHECK-DAG: [[ARRAY_2:.*]] = type { [2 x i64] }
// CHECK-DAG: [[ID:.*]] = type { [[ARRAY_1]] }
// CHECK-DAG: [[RANGE_1:.*]] = type { [[ARRAY_1]] }
// CHECK-DAG: [[RANGE_2:.*]] = type { [[ARRAY_2]] }
// CHECK: define void @test_id([[ID]]* %0, [[ID]]* %1)
// CHECK: define void @test_range.1([[RANGE_1]]* %0)
// CHECK: define void @test_range.2([[RANGE_2]]* %0)
// CHECK-DAG: [[ID_1:%"class.cl::sycl::id.*]] = type { [[ARRAY_1]] }
// CHECK-DAG: [[ID_2:%"class.cl::sycl::id.*]] = type { [[ARRAY_2]] }
// CHECK-DAG: [[RANGE_1:%"class.cl::sycl::range.*]] = type { [[ARRAY_1]] }
// CHECK-DAG: [[RANGE_2:%"class.cl::sycl::range.*]] = type { [[ARRAY_2]] }
// CHECK-DAG: [[ACCESSORIMPLDEVICE_1:%"class.cl::sycl::detail::AccessorImplDevice.*]] = type { [[ID_1]], [[RANGE_1]], [[RANGE_1]] }
// CHECK-DAG: [[ACCESSORIMPLDEVICE_2:%"class.cl::sycl::detail::AccessorImplDevice.*]] = type { [[ID_2]], [[RANGE_2]], [[RANGE_2]] }
// CHECK-DAG: [[ACCESSOR_1:%"class.cl::sycl::accessor.*]] = type { [[ACCESSORIMPLDEVICE_1]], { i32 addrspace(1)* } }
// CHECK-DAG: [[ACCESSOR_2:%"class.cl::sycl::accessor.*]] = type { [[ACCESSORIMPLDEVICE_2]], { i64 addrspace(1)* } }
// CHECK: define void @test_id([[ID_1]] %0, [[ID_1]] %1)
// CHECK: define void @test_range.1([[RANGE_1]] %0)
// CHECK: define void @test_range.2([[RANGE_2]] %0)
// CHECK: define void @test_accessorImplDevice([[ACCESSORIMPLDEVICE_1]] %0)
// CEHCK: define void @test_accessor.1([[ACCESSOR_1]] %0)
// CEHCK: define void @test_accessor.2([[ACCESSOR_2]] %0)

module {
func.func @test_id(%arg0: !sycl.id<1>, %arg1: !sycl.id<1>) {
Expand All @@ -29,4 +37,13 @@ module {
func.func @test_range.2(%arg0: !sycl.range<2>) {
return
}
func.func @test_accessorImplDevice(%arg0: !sycl.accessor_impl_device<[1], (!sycl.id<1>, !sycl.range<1>, !sycl.range<1>)>) {
return
}
func.func @test_accessor.1(%arg0: !sycl.accessor<[1, i32, write, global_buffer], (!sycl.accessor_impl_device<[1], (!sycl.id<1>, !sycl.range<1>, !sycl.range<1>)>)>) {
return
}
func.func @test_accessor.2(%arg0: !sycl.accessor<[2, i64, write, global_buffer], (!sycl.accessor_impl_device<[2], (!sycl.id<2>, !sycl.range<2>, !sycl.range<2>)>)>) {
return
}
}

0 comments on commit 6f4d2dc

Please sign in to comment.