Skip to content

Commit

Permalink
[mlir][Transforms] Dialect conversion: Extra checks during replaceOp
Browse files Browse the repository at this point in the history
This commit adds extra checks/assertions to the `ConversionPatternRewriterImpl::notifyOpReplaced` to improve its robustness.

Replacing an `unrealized_conversion_cast` op that was created by the driver is forbidden and is now caught early during `replaceOp`. It may work in some cases, but is generally dangerous because the conversion driver keeps track of these ops. (Erasing is them is fine.) This change is also in preparation of a subsequent commit that splits the `ConversionValueMapping` into replacements and materializations (with the goal of simplifying block signature conversions).

`null` replacement values are no longer registered in the `ConversionValueMapping`. This was an oversight in #106760. `null` values in the mapping could result in crashes when using the `ConversionValueMapping` API.
  • Loading branch information
matthias-springer committed Sep 29, 2024
1 parent c5d6e99 commit 3399448
Showing 1 changed file with 23 additions and 8 deletions.
31 changes: 23 additions & 8 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1365,27 +1365,42 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
assert(newValues.size() == op->getNumResults());
assert(!ignoredOps.contains(op) && "operation was already replaced");

// Check if replaced op is an unresolved materialization, i.e., an
// unrealized_conversion_cast op that was created by the conversion driver.
bool isUnresolvedMaterialization = false;
if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
if (unresolvedMaterializations.contains(castOp))
isUnresolvedMaterialization = true;

// Create mappings for each of the new result values.
for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) {
if (!newValue) {
// This result was dropped and no replacement value was provided.
if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
if (unresolvedMaterializations.contains(castOp)) {
// Do not create another materializations if we are erasing a
// materialization.
continue;
}
if (isUnresolvedMaterialization) {
// Do not create another materializations if we are erasing a
// materialization.
continue;
}

// Materialize a replacement value "out of thin air".
newValue = buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(result),
result.getLoc(), /*inputs=*/ValueRange(),
/*outputType=*/result.getType(), currentTypeConverter);
} else {
// Make sure that the user does not mess with unresolved materializations
// that were inserted by the conversion driver. We keep track of these
// ops in internal data structures. Erasing them must be allowed because
// this can happen when the user is erasing an entire block (including
// its body). But replacing them with another value should be forbidden
// to avoid problems with the `mapping`.
assert(!isUnresolvedMaterialization &&
"attempting to replace an unresolved materialization");
}

// Remap, and check for any result type changes.
mapping.map(result, newValue);
// Remap result to replacement value.
if (newValue)
mapping.map(result, newValue);
}

appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
Expand Down

0 comments on commit 3399448

Please sign in to comment.