Fix the deallocation problem for sequence op #1673

Merged: 27 commits, Sep 15, 2022

Commits
67af465 compiled and run (chentong319, Aug 30, 2022)
552bf40 Merge remote-tracking branch 'upstream/main' into seq-krnl-2 (chentong319, Aug 30, 2022)
59a839f restore CompilerPasses.cpp (chentong319, Aug 31, 2022)
a545c8a Merge remote-tracking branch 'upstream/main' into seq-krnl-2 (chentong319, Aug 31, 2022)
74cf004 handle copy (chentong319, Sep 2, 2022)
d339d14 Merge remote-tracking branch 'upstream/main' into seq-krnl-2 (chentong319, Sep 2, 2022)
18b13d0 use attr (chentong319, Sep 3, 2022)
6b9b669 lit test (chentong319, Sep 4, 2022)
d84db59 dealloc test (chentong319, Sep 4, 2022)
5762103 doc (chentong319, Sep 6, 2022)
c655be5 Merge remote-tracking branch 'upstream/main' into seq-krnl-2 (chentong319, Sep 6, 2022)
a4c39b1 format (chentong319, Sep 6, 2022)
6ae4b52 ordering (chentong319, Sep 6, 2022)
5c4f172 format (chentong319, Sep 6, 2022)
0209994 Merge remote-tracking branch 'upstream/main' into seq-krnl-2 (chentong319, Sep 6, 2022)
ba6ab93 change attr (chentong319, Sep 7, 2022)
234dc33 Merge remote-tracking branch 'upstream/main' into seq-krnl-2 (chentong319, Sep 7, 2022)
3f41fdc doc (chentong319, Sep 8, 2022)
890cacb Merge remote-tracking branch 'upstream/main' into seq-krnl-2 (chentong319, Sep 8, 2022)
3025192 Merge remote-tracking branch 'upstream/main' into seq-krnl-2 (chentong319, Sep 13, 2022)
1151d09 fix erase (chentong319, Sep 13, 2022)
709e8be doc (chentong319, Sep 13, 2022)
afd03e5 Merge remote-tracking branch 'upstream/main' into seq-krnl-2 (chentong319, Sep 13, 2022)
61583e8 format (chentong319, Sep 13, 2022)
2375ff0 comment (chentong319, Sep 13, 2022)
f1189a0 response (chentong319, Sep 15, 2022)
ba8fb36 Merge remote-tracking branch 'upstream/main' into seq-krnl-2 (chentong319, Sep 15, 2022)
2 changes: 1 addition & 1 deletion src/Compiler/CompilerPasses.cpp
@@ -156,9 +156,9 @@ void addKrnlToLLVMPasses(
   }

   pm.addNestedPass<func::FuncOp>(krnl::createLowerKrnlRegionPass());
-  pm.addNestedPass<func::FuncOp>(krnl::createConvertSeqToMemrefPass());
   pm.addNestedPass<func::FuncOp>(mlir::createConvertSCFToCFPass());

+  pm.addNestedPass<func::FuncOp>(krnl::createConvertSeqToMemrefPass());
   pm.addPass(krnl::createConvertKrnlToLLVMPass(verifyInputTensors));
   pm.addPass(mlir::createReconcileUnrealizedCastsPass());
   pm.addPass(mlir::createCanonicalizerPass());
5 changes: 3 additions & 2 deletions src/Conversion/ONNXToKrnl/ControlFlow/Loop.cpp
@@ -311,8 +311,9 @@ struct ONNXLoopOpLowering : public ConversionPattern {
           rewriter.setInsertionPointToStart(
               &regionOp.bodyRegion().front());
           Value origIV = loopInd[0];
-          auto src = rewriter.create<KrnlSeqExtractOp>(
-              loc, seqElementType, output, origIV);
+          auto src = rewriter.create<KrnlSeqExtractOp>(loc,
+              seqElementType, output, origIV,
+              IntegerAttr::get(rewriter.getIntegerType(1, false), 0));
           emitCopy(rewriter, loc, src, alloc, {origIV});
         });
         newOutputs.emplace_back(alloc);

Collaborator: You could add a createKrnl method to generate a KrnlSeqExtract op, where you can then skip the loc and accept an int for the attribute (hiding the IntegerAttr::get(rewriter.getIntegerType(1, false), 0) inside that method). It just makes the code a bit cleaner; obviously not mandatory, but it would be nice.
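The trailing i1 attribute added here is the op's copy flag: 0 asks KrnlSeqExtractOp for an aliasing extract (safe in this loop because emitCopy immediately copies the element out), while 1 asks for an owned deep copy. A plain C++ model of that ownership contract, using hypothetical names (`Tensor`, `Sequence`, `seqExtract`) that merely stand in for the MLIR types:

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// A sequence of tensors, modeled as shared buffers.
using Tensor = std::vector<float>;
using Sequence = std::vector<std::shared_ptr<Tensor>>;

// copy == 0: return an alias of the stored element (the sequence keeps
// ownership, so the caller must not free it).
// copy == 1: return a fresh deep copy that the caller owns.
std::shared_ptr<Tensor> seqExtract(const Sequence &seq, size_t pos, int copy) {
  if (copy == 0)
    return seq[pos];                          // alias into the sequence
  return std::make_shared<Tensor>(*seq[pos]); // independent deep copy
}
```

With copy == 0 the result points at the same buffer as the sequence slot; with copy == 1 it is a distinct allocation, which is why SequenceErase can later dealloc sequence elements without invalidating copied-out results.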
5 changes: 3 additions & 2 deletions src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
@@ -347,8 +347,9 @@ void FrontendToKrnlLoweringPass::runOnOperation() {

   // If `emitDealloc` is turned off, make sure we don't have buffer deallocation
   // at this level. Will use MLIR buffer-deallocation for this purpose instead.
-  if (!emitDealloc)
-    target.addIllegalOp<mlir::memref::DeallocOp>();
+  // However, since the SequenceErase needs to emit memref dealloc, the
+  // following statement is commented out (Chentong)
+  // if (!emitDealloc) target.addIllegalOp<mlir::memref::DeallocOp>();

   // TODO: enable this once more ops are supported.
   // We also define the ONNX dialect as Illegal so that the conversion will fail

Collaborator: Not for now, but I think we are pretty sure by now not to use the old buffer scheme. At some point we could do a cleanup of the code, eliminating all that is no longer needed. Definitely not for this PR.

Author (chentong319): This PR needs to comment out the code because dealloc will be generated by SequenceErase.
18 changes: 11 additions & 7 deletions src/Conversion/ONNXToKrnl/Sequence/SequenceAt.cpp
@@ -31,17 +31,21 @@ struct ONNXSequenceAtOpLowering : public ConversionPattern {
     MultiDialectBuilder<KrnlBuilder, MemRefBuilder> create(rewriter, loc);
     IndexExprScope IEScope(&rewriter, loc);

-    auto input_sequence = operandAdaptor.input_sequence();
+    Value input_sequence = operandAdaptor.input_sequence();
+    Type outputMemRefType =
+        input_sequence.getType().cast<MemRefType>().getElementType();
     auto dimSize = create.mem.dim(input_sequence, 0);
     SymbolIndexExpr boundIE(dimSize);
     IndexExpr positionIE =
         SymbolIndexExpr(create.krnl.load(operandAdaptor.position()));

-    // Just for lit test to pass
-    auto correctionIE = positionIE + boundIE;
-    positionIE = IndexExpr::select(positionIE < 0, correctionIE, positionIE);
-    auto outputVal = create.krnl.load(
-        operandAdaptor.input_sequence(), positionIE.getValue());
+    // Handle the negative position
+    IndexExpr condIE = positionIE < 0;
+    IndexExpr fixedPosition = positionIE + boundIE;
+    positionIE = IndexExpr::select(condIE, fixedPosition, positionIE);
+
+    Value outputVal = rewriter.create<KrnlSeqExtractOp>(loc, outputMemRefType,
+        input_sequence, positionIE.getValue(),
+        IntegerAttr::get(rewriter.getIntegerType(1, false), 1));

     rewriter.replaceOp(op, outputVal);
     return success();

Collaborator: If you created a method to add this op, use it here too with create.krnl.

Author (chentong319): Not a big fan of the dialect builder, especially when there is nothing abstracted out by the builder.

Collaborator: Well, we can abstract the handling of the int attribute, which I personally don't like having to remember as an awkward sequence... but if you don't want to do it, that is fine.
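The condIE/fixedPosition pair implements ONNX's negative-indexing rule: a position p in [-n, -1] refers to element p + n of a length-n sequence. The same select, reduced to a scalar sketch (the helper name `normalizePosition` is illustrative, not from the PR):

```cpp
#include <cassert>

// Mirrors IndexExpr::select(positionIE < 0, positionIE + boundIE, positionIE):
// a negative position counts from the end of the sequence.
long normalizePosition(long position, long bound) {
  return position < 0 ? position + bound : position;
}
```

After normalization the position is a plain in-bounds index, which is what KrnlSeqExtractOp receives.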
11 changes: 7 additions & 4 deletions src/Conversion/ONNXToKrnl/Sequence/SequenceErase.cpp
@@ -28,7 +28,7 @@ struct ONNXSequenceEraseOpLowering : public ConversionPattern {
       ConversionPatternRewriter &rewriter) const final {
     Location loc = op->getLoc();
     ONNXSequenceEraseOpAdaptor operandAdaptor(operands);
-    ONNXSequenceInsertOp thisOp = dyn_cast<ONNXSequenceInsertOp>(op);
+    ONNXSequenceEraseOp thisOp = dyn_cast<ONNXSequenceEraseOp>(op);
     MultiDialectBuilder<MathBuilder, MemRefBuilder> create(rewriter, loc);
     IndexExprScope IEScope(&rewriter, loc);

@@ -66,7 +66,7 @@
       positionIE = IndexExpr::select(positionIE < 0, correctionIE, positionIE);
     }

-    // Copy before the insert
+    // Copy the elements before the position
     KrnlBuilder createKrnl(rewriter, loc);
     SmallVector<IndexExpr, 1> lbs;
     lbs.emplace_back(LiteralIndexExpr(0));

Collaborator: Great usage of consistent language, big thanks.

@@ -80,9 +80,12 @@
           createKrnl.store(element, alloc, indicesLoopInd[0]);
         });

-    // ToDo (chentong)Free the erased element
+    // Free the element to be erased
+    Value element =
+        createKrnl.load(operandAdaptor.input_sequence(), positionIE.getValue());
+    create.mem.dealloc(element);

-    // Copy after the insert
+    // Copy the elements after the position
     SmallVector<IndexExpr, 1> lbs1;
     lbs1.emplace_back(positionIE + 1);
     SmallVector<IndexExpr, 1> ubs1;
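The lowering erases element `position` in three steps: copy the elements before it, free the erased element's buffer (the PR's deallocation fix), then copy the elements after it shifted down by one. A plain C++ sketch of the same three steps, using raw pointers to make the ownership transfer explicit (the `seqErase` helper is illustrative, not part of the PR):

```cpp
#include <cassert>
#include <cstdlib>
#include <vector>

// The output sequence reuses the surviving element pointers; only the
// erased element's buffer is freed.
std::vector<int *> seqErase(const std::vector<int *> &seq, size_t pos) {
  std::vector<int *> out(seq.size() - 1);
  for (size_t i = 0; i < pos; ++i)              // copy elements before pos
    out[i] = seq[i];
  free(seq[pos]);                               // free the erased element
  for (size_t i = pos + 1; i < seq.size(); ++i) // copy elements after, shifted
    out[i - 1] = seq[i];
  return out;
}
```

This is also why the pass had to stop marking memref::DeallocOp illegal: the erase lowering itself must emit a dealloc.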
55 changes: 7 additions & 48 deletions src/Conversion/ONNXToKrnl/Sequence/SequenceInsert.cpp
@@ -40,21 +40,11 @@ struct ONNXSequenceInsertOpLowering : public ConversionPattern {
         "Failed to convert type to MemRefType");
     MemRefType outputMemRefType = convertedType.cast<MemRefType>();

-    auto seqElementConvertedType =
-        outputMemRefType.getElementType().cast<MemRefType>();
     auto input_sequence = operandAdaptor.input_sequence();
     auto dimSize = create.mem.dim(input_sequence, 0);
     SymbolIndexExpr boundIE(dimSize);

-    // Output sequence has one more element
-    auto outputBound = boundIE + 1;
-    SmallVector<IndexExpr, 1> ubsIE;
-    ubsIE.emplace_back(outputBound);
-    Value alloc =
-        insertAllocAndDeallocSimple(rewriter, op, outputMemRefType, loc, ubsIE);
-
-    // Fill the output sequence
-
+    // Handle Optional and negative position
     IndexExpr positionIE;
     if (isFromNone(operandAdaptor.position())) {
       // Insert at the end of the sequence

@@ -65,45 +55,14 @@
     } else {
       positionIE = SymbolIndexExpr(create.krnl.load(operandAdaptor.position()));
       // Handle the negative position
-      positionIE =
-          IndexExpr::select(positionIE < 0, positionIE + boundIE, positionIE);
+      IndexExpr condIE = positionIE < 0;
+      IndexExpr fixedPosition = positionIE + boundIE;
+      positionIE = IndexExpr::select(condIE, fixedPosition, positionIE);
     }

-    // Copy elements before the insertion position
-    SmallVector<IndexExpr, 1> lbs;
-    lbs.emplace_back(LiteralIndexExpr(0));
-    SmallVector<IndexExpr, 1> ubs;
-    ubs.emplace_back(positionIE);
-    ValueRange firstLoopDef = create.krnl.defineLoops(1);
-    create.krnl.iterateIE(firstLoopDef, firstLoopDef, lbs, ubs,
-        [&](KrnlBuilder createKrnl, ValueRange indicesLoopInd) {
-          auto element = createKrnl.load(
-              operandAdaptor.input_sequence(), indicesLoopInd[0]);
-          auto converted = create.mem.cast(element, seqElementConvertedType);
-          createKrnl.store(converted, alloc, indicesLoopInd[0]);
-        });
-
-    // Insert the input tensor
-    // ToDo (chentong): need to duplicate the tensor
-    auto element =
-        create.mem.cast(operandAdaptor.tensor(), seqElementConvertedType);
-    create.krnl.store(element, alloc, positionIE.getValue());
-
-    // Copy elements after the insertion position
-    SmallVector<IndexExpr, 1> lbs1;
-    lbs1.emplace_back(positionIE + 1);
-    SmallVector<IndexExpr, 1> ubs1;
-    ubs1.emplace_back(outputBound);
-    ValueRange secondLoopDef = create.krnl.defineLoops(1);
-    create.krnl.iterateIE(secondLoopDef, secondLoopDef, lbs1, ubs1,
-        [&](KrnlBuilder createKrnl, ValueRange indicesLoopInd) {
-          auto element = createKrnl.load(
-              operandAdaptor.input_sequence(), indicesLoopInd[0]);
-          auto converted = create.mem.cast(element, seqElementConvertedType);
-          auto outputIndex =
-              create.math.add(indicesLoopInd[0], create.math.constantIndex(1));
-          createKrnl.store(converted, alloc, outputIndex);
-        });
+    Value alloc = rewriter.create<KrnlSeqInsertOp>(loc, outputMemRefType,
+        operandAdaptor.tensor(), operandAdaptor.input_sequence(),
+        positionIE.getValue());

     rewriter.replaceOp(op, alloc);
     return success();
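The inlined copy loops are replaced by a single KrnlSeqInsertOp that encapsulates the insert semantics: when the optional position input is absent the tensor goes at the end, and a negative position wraps from the end. A standalone sketch of those semantics under assumed names (`seqInsert`, `atEnd` are illustrative):

```cpp
#include <cassert>
#include <vector>

// atEnd models the absent optional position (ONNX default: append);
// otherwise a negative position counts from the end of the sequence.
std::vector<int> seqInsert(std::vector<int> seq, int tensor,
                           long position, bool atEnd) {
  long bound = static_cast<long>(seq.size());
  long pos = atEnd ? bound : (position < 0 ? position + bound : position);
  seq.insert(seq.begin() + pos, tensor); // elements at pos and after shift up
  return seq;
}
```

Folding the two copy loops into one op is what lets the later SeqToMemref pass decide how elements are copied or aliased, instead of fixing that policy at ONNX-to-Krnl time.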
1 change: 1 addition & 0 deletions src/Conversion/SeqToMemref/CMakeLists.txt
@@ -3,6 +3,7 @@
 add_onnx_mlir_library(OMSeqToMemref
   ConvertSeqToMemref.cpp
   KrnlSeqExtract.cpp
+  KrnlSeqInsert.cpp
   KrnlSeqStore.cpp

   LINK_LIBS PUBLIC
4 changes: 3 additions & 1 deletion src/Conversion/SeqToMemref/ConvertSeqToMemref.cpp
@@ -59,10 +59,11 @@ void ConvertSeqToMemrefPass::runOnOperation() {
   ConversionTarget target(getContext());

   target.addIllegalOp<KrnlSeqExtractOp>();
+  target.addIllegalOp<KrnlSeqInsertOp>();
   target.addIllegalOp<KrnlSeqStoreOp>();
   target.addLegalDialect<mlir::AffineDialect, mlir::arith::ArithmeticDialect,
       mlir::memref::MemRefDialect, mlir::func::FuncDialect,
-      mlir::vector::VectorDialect>();
+      mlir::vector::VectorDialect, mlir::scf::SCFDialect>();

   // Now that the conversion target has been defined, we just need to provide
   // the set of patterns that will lower the frontend operations.

@@ -71,6 +72,7 @@
   // Define patterns.
   KrnlTypeConverter typeConverter;
   populateLoweringKrnlSeqExtractOpPattern(typeConverter, patterns, ctx);
+  populateLoweringKrnlSeqInsertOpPattern(typeConverter, patterns, ctx);
   populateLoweringKrnlSeqStoreOpPattern(typeConverter, patterns, ctx);

   if (failed(applyPartialConversion(
3 changes: 3 additions & 0 deletions src/Conversion/SeqToMemref/ConvertSeqToMemref.hpp
@@ -24,6 +24,9 @@ namespace krnl {
 void populateLoweringKrnlSeqExtractOpPattern(mlir::TypeConverter &typeConverter,
     mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);

+void populateLoweringKrnlSeqInsertOpPattern(mlir::TypeConverter &typeConverter,
+    mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
+
 void populateLoweringKrnlSeqStoreOpPattern(mlir::TypeConverter &typeConverter,
     mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
24 changes: 22 additions & 2 deletions src/Conversion/SeqToMemref/KrnlSeqExtract.cpp
@@ -20,6 +20,7 @@
 #include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp"
 #include "src/Dialect/Krnl/KrnlHelper.hpp"
 #include "src/Dialect/Krnl/KrnlOps.hpp"
+#include "src/Dialect/Mlir/DialectBuilder.hpp"
 #include "llvm/Support/Debug.h"

 #define DEBUG_TYPE "krnl_to_llvm"

@@ -40,7 +41,9 @@ class KrnlSeqExtractOpLowering : public ConversionPattern {
   LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const override {
     KrnlSeqExtractOpAdaptor operandAdaptor(operands);
+    KrnlSeqExtractOp thisOp = dyn_cast<KrnlSeqExtractOp>(op);
     auto loc = op->getLoc();
+    MultiDialectBuilder<MathBuilder, MemRefBuilder> create(rewriter, loc);

     auto output = rewriter
                       .create<memref::LoadOp>(

@@ -51,8 +54,25 @@
     // if the element is read from seq after extracted, or deep deallocation
     // is added when seq is freed

-    rewriter.replaceOp(op, output);
-    return success();
+    if (thisOp.copy() == 0) {
+      rewriter.replaceOp(op, output);
+      return success();
+    } else {
+      if (!output.getType().isa<MemRefType>())
+        llvm_unreachable(
+            "Not implemented: type of onnx seq element is not tensor");
+      auto outputType = output.getType().cast<MemRefType>();
+      SmallVector<mlir::Value, 4> allocParams;
+      for (size_t i = 0; i < outputType.getShape().size(); i++) {
+        if (outputType.getShape()[i] == -1) {
+          allocParams.emplace_back(create.mem.dim(output, i));
+        }
+      }
+      Value alloc = create.mem.alignedAlloc(outputType, allocParams);
+      rewriter.create<memref::CopyOp>(loc, output, alloc);
+      rewriter.replaceOp(op, alloc);
+      return success();
+    }
   }
 };

Collaborator: This feels like an assert :-)

Author (chentong319): In the future, we will handle the non-MemRefType case. I usually use unreachable for that purpose: not a semantic error but an implementation limitation. Not sure about the convention in LLVM.
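The allocParams loop above collects one runtime dimension value for each dynamic dimension, since a MemRef shape marks dimensions not known at compile time (here encoded as -1) and an aligned alloc needs exactly those values, in order, while static dimensions are baked into the type. A minimal sketch of that collection logic, with plain integers standing in for the `memref.dim` SSA values (assumed names, not the PR's API):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// shape: -1 marks a dynamic dimension; runtimeDims: the dimension values
// that would come from memref.dim. Returns only the dynamic ones, in order,
// which is the operand list an aligned alloc of this shape would take.
std::vector<long> collectAllocParams(const std::vector<long> &shape,
                                     const std::vector<long> &runtimeDims) {
  std::vector<long> allocParams;
  for (size_t i = 0; i < shape.size(); ++i)
    if (shape[i] == -1)
      allocParams.push_back(runtimeDims[i]); // dim of the source at index i
  return allocParams;
}
```

For a fully static shape the result is empty and the allocation needs no operands.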