From 4509d48692296cc5ef7e69b924870cbd72fc252f Mon Sep 17 00:00:00 2001 From: Eric Schweitz Date: Tue, 26 Jan 2021 10:18:44 -0800 Subject: [PATCH] [mlir] sret and byval now require a type argument when constructed. Fixes the LLVM code gen bugs and adds the missing tests. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D95378 --- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 6 ++++-- mlir/test/Target/llvmir.mlir | 10 ++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 87ec35cc1c42..f9b714738ffe 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -1105,7 +1105,8 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) { if (!argTy.isa()) return func.emitError( "llvm.sret attribute attached to LLVM non-pointer argument"); - llvmArg.addAttr(llvm::Attribute::AttrKind::StructRet); + llvmArg.addAttrs(llvm::AttrBuilder().addStructRetAttr( + llvmArg.getType()->getPointerElementType())); } if (auto attr = func.getArgAttrOfType(argIdx, "llvm.byval")) { @@ -1113,7 +1114,8 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) { if (!argTy.isa()) return func.emitError( "llvm.byval attribute attached to LLVM non-pointer argument"); - llvmArg.addAttr(llvm::Attribute::AttrKind::ByVal); + llvmArg.addAttrs(llvm::AttrBuilder().addByValAttr( + llvmArg.getType()->getPointerElementType())); } valueMapping[mlirArg] = &llvmArg; diff --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir index 4645ef96a9d0..05c309bb6f0d 100644 --- a/mlir/test/Target/llvmir.mlir +++ b/mlir/test/Target/llvmir.mlir @@ -939,6 +939,16 @@ llvm.func @llvm_noalias(%arg0: !llvm.ptr {llvm.noalias = true}) { llvm.return } +// CHECK-LABEL: define void @byvalattr(i32* byval(i32) % +llvm.func @byvalattr(%arg0: !llvm.ptr {llvm.byval}) { + llvm.return +} + +// CHECK-LABEL: define void @sretattr(i32* sret(i32) % +llvm.func @sretattr(%arg0: !llvm.ptr {llvm.sret}) { + llvm.return +} + // CHECK-LABEL: define void @llvm_align(float* align 4 {{%*.}}) llvm.func @llvm_align(%arg0: !llvm.ptr {llvm.align = 4}) { llvm.return