Skip to content

Commit

Permalink
[mlir][Transforms] Dialect conversion: Align handling of dropped values
Browse files Browse the repository at this point in the history
Handle dropped block arguments and dropped op results in the same way: build a source materialization (that may fold away if unused). This simplifies the code base a bit and makes it possible to merge `legalizeConvertedArgumentTypes` and `legalizeConvertedOpResultTypes` in a future commit. These two functions are almost doing the same thing now.

This commit also fixes a bug where circular materializations were built, e.g.:
```
%0 = "builtin.unrealized_conversion_cast"(%1) : (!a) -> !b
%1 = "builtin.unrealized_conversion_cast"(%0) : (!b) -> !a
// No further uses of %0, %1.
```

This happened when:
1. An op was erased. (No replacement values provided.)
2. A conversion pattern for another op builds a replacement value (first cast op) during `remapValues`, but that SSA value is not used during the pattern application.
3. During the finalization phase, `legalizeConvertedOpResultTypes` thinks that the erased op is alive because of the cast op that was built in Step 2. It builds a cast from that replacement value to the original type.
4. During the commit phase, all uses of the original op are repalced with the casted value produced in Step 3. We have generated circular IR.
  • Loading branch information
matthias-springer committed Aug 30, 2024
1 parent 57fe53c commit d72b58e
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 41 deletions.
62 changes: 23 additions & 39 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// to modify/access them is invalid rewriter API usage.
SetVector<Operation *> replacedOps;

DenseSet<Operation *> unresolvedMaterializations;

/// The current type converter, or nullptr if no type converter is currently
/// active.
const TypeConverter *currentTypeConverter = nullptr;
Expand Down Expand Up @@ -1055,6 +1057,7 @@ void UnresolvedMaterializationRewrite::rollback() {
for (Value input : op->getOperands())
rewriterImpl.mapping.erase(input);
}
rewriterImpl.unresolvedMaterializations.erase(op);
op->erase();
}

Expand Down Expand Up @@ -1341,6 +1344,7 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
auto convertOp =
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
unresolvedMaterializations.insert(convertOp);
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
return convertOp.getResult(0);
}
Expand Down Expand Up @@ -1379,9 +1383,21 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
// Create mappings for each of the new result values.
for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) {
if (!newValue) {
resultChanged = true;
continue;
// This result was dropped and no replacement value was provided.
if (unresolvedMaterializations.contains(op)) {
// Do not create another materializations if we are erasing a
// materialization.
resultChanged = true;
continue;
}

// Materialize a replacement value "out of thin air".
newValue = buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(result),
result.getLoc(), /*inputs=*/ValueRange(),
/*outputType=*/result.getType(), currentTypeConverter);
}

// Remap, and check for any result type changes.
mapping.map(result, newValue);
resultChanged |= (newValue.getType() != result.getType());
Expand Down Expand Up @@ -2359,11 +2375,6 @@ struct OperationConverter {
ConversionPatternRewriterImpl &rewriterImpl,
DenseMap<Value, SmallVector<Value>> &inverseMapping);

/// Legalize an operation result that was marked as "erased".
LogicalResult
legalizeErasedResult(Operation *op, OpResult result,
ConversionPatternRewriterImpl &rewriterImpl);

/// Dialect conversion configuration.
ConversionConfig config;

Expand Down Expand Up @@ -2500,26 +2511,18 @@ LogicalResult OperationConverter::legalizeConvertedOpResultTypes(
continue;
Operation *op = opReplacement->getOperation();
for (OpResult result : op->getResults()) {
Value newValue = rewriterImpl.mapping.lookupOrNull(result);

// If the operation result was replaced with null, all of the uses of this
// value should be replaced.
if (!newValue) {
if (failed(legalizeErasedResult(op, result, rewriterImpl)))
return failure();
// If the type of this op result changed and the result is still live,
// we need to materialize a conversion.
if (rewriterImpl.mapping.lookupOrNull(result, result.getType()))
continue;
}

// Otherwise, check to see if the type of the result changed.
if (result.getType() == newValue.getType())
continue;

Operation *liveUser =
findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
if (!liveUser)
continue;

// Legalize this result.
Value newValue = rewriterImpl.mapping.lookupOrNull(result);
assert(newValue && "replacement value not found");
Value castValue = rewriterImpl.buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(result), op->getLoc(),
/*inputs=*/newValue, /*outputType=*/result.getType(),
Expand Down Expand Up @@ -2850,25 +2853,6 @@ LogicalResult OperationConverter::legalizeUnresolvedMaterializations(
return success();
}

LogicalResult OperationConverter::legalizeErasedResult(
Operation *op, OpResult result,
ConversionPatternRewriterImpl &rewriterImpl) {
// If the operation result was replaced with null, all of the uses of this
// value should be replaced.
auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
return rewriterImpl.isOpIgnored(user);
});
if (liveUserIt != result.user_end()) {
InFlightDiagnostic diag = op->emitError("failed to legalize operation '")
<< op->getName() << "' marked as erased";
diag.attachNote(liveUserIt->getLoc())
<< "found live user of result #" << result.getResultNumber() << ": "
<< *liveUserIt;
return failure();
}
return success();
}

//===----------------------------------------------------------------------===//
// Reconcile Unrealized Casts
//===----------------------------------------------------------------------===//
Expand Down
3 changes: 1 addition & 2 deletions mlir/test/Transforms/test-legalize-erased-op-with-uses.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
// Test that an error is emitted when an operation is marked as "erased", but
// has users that live across the conversion.
func.func @remove_all_ops(%arg0: i32) -> i32 {
// expected-error@below {{failed to legalize operation 'test.illegal_op_a' marked as erased}}
// expected-error@below {{failed to legalize unresolved materialization from () to 'i32' that remained live after conversion}}
%0 = "test.illegal_op_a"() : () -> i32
// expected-note@below {{found live user of result #0: func.return %0 : i32}}
return %0 : i32
}

0 comments on commit d72b58e

Please sign in to comment.