-
Notifications
You must be signed in to change notification settings - Fork 329
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix the deallocation problem for sequence op #1673
Changes from all commits
67af465
552bf40
59a839f
a545c8a
74cf004
d339d14
18b13d0
6b9b669
d84db59
5762103
c655be5
a4c39b1
6ae4b52
5c4f172
0209994
ba6ab93
234dc33
3f41fdc
890cacb
3025192
1151d09
709e8be
afd03e5
61583e8
2375ff0
f1189a0
ba8fb36
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -347,8 +347,9 @@ void FrontendToKrnlLoweringPass::runOnOperation() { | |
|
||
// If `emitDealloc` is turned off, make sure we don't have buffer deallocation | ||
// at this level. Will use MLIR buffer-deallocation for this purpose instead. | ||
if (!emitDealloc) | ||
target.addIllegalOp<mlir::memref::DeallocOp>(); | ||
// However, since the SequenceErase needs to emit memref dealloc, the previous | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not for now, but I think we are pretty sure by now not to use the old buffer scheme. At sometimes, we could do a cleanup of the code, and eliminating all that is not needed anymore. Definitely not for this PR. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This PR needs to comment out the code because dealloc will be generated by SequenceErase. |
||
// the following statement is commented out (Chentong) | ||
// if (!emitDealloc) target.addIllegalOp<mlir::memref::DeallocOp>(); | ||
|
||
// TODO: enable this once more ops are supported. | ||
// We also define the ONNX dialect as Illegal so that the conversion will fail | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,17 +31,21 @@ struct ONNXSequenceAtOpLowering : public ConversionPattern { | |
MultiDialectBuilder<KrnlBuilder, MemRefBuilder> create(rewriter, loc); | ||
IndexExprScope IEScope(&rewriter, loc); | ||
|
||
auto input_sequence = operandAdaptor.input_sequence(); | ||
Value input_sequence = operandAdaptor.input_sequence(); | ||
Type outputMemRefType = | ||
input_sequence.getType().cast<MemRefType>().getElementType(); | ||
auto dimSize = create.mem.dim(input_sequence, 0); | ||
SymbolIndexExpr boundIE(dimSize); | ||
IndexExpr positionIE = | ||
SymbolIndexExpr(create.krnl.load(operandAdaptor.position())); | ||
|
||
// Just for lit test to pass | ||
auto correctionIE = positionIE + boundIE; | ||
positionIE = IndexExpr::select(positionIE < 0, correctionIE, positionIE); | ||
auto outputVal = create.krnl.load( | ||
operandAdaptor.input_sequence(), positionIE.getValue()); | ||
// Handle the negative position | ||
IndexExpr condIE = positionIE < 0; | ||
IndexExpr fixedPosition = positionIE + boundIE; | ||
positionIE = IndexExpr::select(condIE, fixedPosition, positionIE); | ||
|
||
Value outputVal = rewriter.create<KrnlSeqExtractOp>(loc, outputMemRefType, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if you created a method to add this op, use it here too with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not a big fan of the dialect builder, especially when there is nothing abstracted out by builder. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well we can abstract the handling of the int attribute, which I personally don't like to remember this awkward sequence... but if you don't want to do it, it is fine. |
||
input_sequence, positionIE.getValue(), | ||
IntegerAttr::get(rewriter.getIntegerType(1, false), 1)); | ||
|
||
rewriter.replaceOp(op, outputVal); | ||
return success(); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,7 +28,7 @@ struct ONNXSequenceEraseOpLowering : public ConversionPattern { | |
ConversionPatternRewriter &rewriter) const final { | ||
Location loc = op->getLoc(); | ||
ONNXSequenceEraseOpAdaptor operandAdaptor(operands); | ||
ONNXSequenceInsertOp thisOp = dyn_cast<ONNXSequenceInsertOp>(op); | ||
ONNXSequenceEraseOp thisOp = dyn_cast<ONNXSequenceEraseOp>(op); | ||
MultiDialectBuilder<MathBuilder, MemRefBuilder> create(rewriter, loc); | ||
IndexExprScope IEScope(&rewriter, loc); | ||
|
||
|
@@ -66,7 +66,7 @@ struct ONNXSequenceEraseOpLowering : public ConversionPattern { | |
positionIE = IndexExpr::select(positionIE < 0, correctionIE, positionIE); | ||
} | ||
|
||
// Copy before the insert | ||
// Copy the elements before the position | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. great usage of consistent language, big thanks |
||
KrnlBuilder createKrnl(rewriter, loc); | ||
SmallVector<IndexExpr, 1> lbs; | ||
lbs.emplace_back(LiteralIndexExpr(0)); | ||
|
@@ -80,9 +80,12 @@ struct ONNXSequenceEraseOpLowering : public ConversionPattern { | |
createKrnl.store(element, alloc, indicesLoopInd[0]); | ||
}); | ||
|
||
// ToDo (chentong)Free the erased element | ||
// Free the element to be erased | ||
Value element = | ||
createKrnl.load(operandAdaptor.input_sequence(), positionIE.getValue()); | ||
create.mem.dealloc(element); | ||
|
||
// Copy after the insert | ||
// Copy the elements after the position | ||
SmallVector<IndexExpr, 1> lbs1; | ||
lbs1.emplace_back(positionIE + 1); | ||
SmallVector<IndexExpr, 1> ubs1; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ | |
#include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp" | ||
#include "src/Dialect/Krnl/KrnlHelper.hpp" | ||
#include "src/Dialect/Krnl/KrnlOps.hpp" | ||
#include "src/Dialect/Mlir/DialectBuilder.hpp" | ||
#include "llvm/Support/Debug.h" | ||
|
||
#define DEBUG_TYPE "krnl_to_llvm" | ||
|
@@ -40,7 +41,9 @@ class KrnlSeqExtractOpLowering : public ConversionPattern { | |
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, | ||
ConversionPatternRewriter &rewriter) const override { | ||
KrnlSeqExtractOpAdaptor operandAdaptor(operands); | ||
KrnlSeqExtractOp thisOp = dyn_cast<KrnlSeqExtractOp>(op); | ||
auto loc = op->getLoc(); | ||
MultiDialectBuilder<MathBuilder, MemRefBuilder> create(rewriter, loc); | ||
|
||
auto output = rewriter | ||
.create<memref::LoadOp>( | ||
|
@@ -51,8 +54,25 @@ class KrnlSeqExtractOpLowering : public ConversionPattern { | |
// if the element is read from seq after extracted, or deep deallocation | ||
// is added when seq is freed | ||
|
||
rewriter.replaceOp(op, output); | ||
return success(); | ||
if (thisOp.copy() == 0) { | ||
rewriter.replaceOp(op, output); | ||
return success(); | ||
} else { | ||
if (!output.getType().isa<MemRefType>()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this feels like an assert :-) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In future, we will handle the non MemRefType case. I usually use unreachable for that purpose: not a semantic error but an implementation limitation. Not sure about the convention in llvm. |
||
llvm_unreachable( | ||
"Not implemented: type of onnx seq element is not tensor"); | ||
auto outputType = output.getType().cast<MemRefType>(); | ||
SmallVector<mlir::Value, 4> allocParams; | ||
for (size_t i = 0; i < outputType.getShape().size(); i++) { | ||
if (outputType.getShape()[i] == -1) { | ||
allocParams.emplace_back(create.mem.dim(output, i)); | ||
} | ||
} | ||
Value alloc = create.mem.alignedAlloc(outputType, allocParams); | ||
rewriter.create<memref::CopyOp>(loc, output, alloc); | ||
rewriter.replaceOp(op, alloc); | ||
return success(); | ||
} | ||
} | ||
}; | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you could add a createKrnl method to generate a KrnlSeqExtract op where you can then skip the
loc
, and could accept an int for the attribute (and hide theIntegerAttr::get(rewriter.getIntegerType(1, false), 0))
inside that method.just make the code a bit cleaner, obviously not mandatory but it would be nice.