Skip to content

Commit

Permalink
[mlir][Transforms][NFC] Dialect conversion: Eagerly build reverse map…
Browse files Browse the repository at this point in the history
…ping

The "inverse mapping" is an inverse IRMapping that points from replaced values to their original values. This inverse mapping is needed when legalizing unresolved materializations, to figure out if a value has any uses. (It is not sufficient to examine the IR, because some IR changes have not been materialized yet.)

There was a check in `OperationConverter::finalize` that computed the inverse mapping only when needed. This check is not needed. `legalizeUnresolvedMaterializations` always computes the inverse mapping, so we can just do that in `OperationConverter::finalize` before calling `legalizeUnresolvedMaterializations`.
  • Loading branch information
matthias-springer committed Aug 1, 2024
1 parent a0fccfa commit 8752547
Showing 1 changed file with 12 additions and 14 deletions.
26 changes: 12 additions & 14 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2352,7 +2352,7 @@ struct OperationConverter {
LogicalResult legalizeUnresolvedMaterializations(
ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &rewriterImpl,
std::optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping);
DenseMap<Value, SmallVector<Value>> &inverseMapping);

/// Legalize an operation result that was marked as "erased".
LogicalResult
Expand Down Expand Up @@ -2454,10 +2454,12 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {

LogicalResult
OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
std::optional<DenseMap<Value, SmallVector<Value>>> inverseMapping;
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)) ||
failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
return failure();
DenseMap<Value, SmallVector<Value>> inverseMapping =
rewriterImpl.mapping.getInverse();
if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
inverseMapping)))
return failure();

Expand All @@ -2483,15 +2485,11 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
if (result.getType() == newValue.getType())
continue;

// Compute the inverse mapping only if it is really needed.
if (!inverseMapping)
inverseMapping = rewriterImpl.mapping.getInverse();

// Legalize this result.
rewriter.setInsertionPoint(op);
if (failed(legalizeChangedResultType(
op, result, newValue, opReplacement->getConverter(), rewriter,
rewriterImpl, *inverseMapping)))
rewriterImpl, inverseMapping)))
return failure();
}
}
Expand All @@ -2503,6 +2501,8 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
ConversionPatternRewriterImpl &rewriterImpl) {
// Functor used to check if all users of a value will be dead after
// conversion.
// TODO: This should probably query the inverse mapping, same as in
// `legalizeChangedResultType`.
auto findLiveUser = [&](Value val) {
auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
return rewriterImpl.isOpIgnored(user);
Expand Down Expand Up @@ -2796,20 +2796,18 @@ static LogicalResult legalizeUnresolvedMaterialization(
LogicalResult OperationConverter::legalizeUnresolvedMaterializations(
ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &rewriterImpl,
std::optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping) {
inverseMapping = rewriterImpl.mapping.getInverse();

DenseMap<Value, SmallVector<Value>> &inverseMapping) {
// As an initial step, compute all of the inserted materializations that we
// expect to persist beyond the conversion process.
DenseMap<Operation *, UnresolvedMaterializationRewrite *> materializationOps;
SetVector<UnresolvedMaterializationRewrite *> necessaryMaterializations;
computeNecessaryMaterializations(materializationOps, rewriter, rewriterImpl,
*inverseMapping, necessaryMaterializations);
inverseMapping, necessaryMaterializations);

// Once computed, legalize any necessary materializations.
for (auto *mat : necessaryMaterializations) {
if (failed(legalizeUnresolvedMaterialization(
*mat, materializationOps, rewriter, rewriterImpl, *inverseMapping)))
*mat, materializationOps, rewriter, rewriterImpl, inverseMapping)))
return failure();
}
return success();
Expand Down

0 comments on commit 8752547

Please sign in to comment.