Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "[mlir][Transforms] Dialect conversion: Make materializations optional" #106778

Conversation

matthias-springer
Copy link
Member

Reverts #104668

This commit triggers an edge case that can cause circular unrealized_conversion_cast ops. #106760 may fix it, but it is has other issues. Reverting this PR for now, until I find a solution for that problem.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:gpu mlir mlir:bufferization Bufferization infrastructure labels Aug 30, 2024
@llvmbot
Copy link
Member

llvmbot commented Aug 30, 2024

@llvm/pr-subscribers-mlir-bufferization
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Matthias Springer (matthias-springer)

Changes

Reverts llvm/llvm-project#104668

This commit triggers an edge case that can cause circular unrealized_conversion_cast ops. #106760 may fix it, but it is has other issues. Reverting this PR for now, until I find a solution for that problem.


Patch is 31.92 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/106778.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (-11)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+294-99)
  • (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+3-2)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir (-1)
  • (modified) mlir/test/Transforms/test-legalize-type-conversion.mlir (+1-5)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 5f680e8eca7559..60113bdef16a23 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1124,17 +1124,6 @@ struct ConversionConfig {
   // already been modified) and iterators into past IR state cannot be
   // represented at the moment.
   RewriterBase::Listener *listener = nullptr;
-
-  /// If set to "true", the dialect conversion attempts to build source/target/
-  /// argument materializations through the type converter API in lieu of
-  /// builtin.unrealized_conversion_cast ops. The conversion process fails if
-  /// at least one materialization could not be built.
-  ///
-  /// If set to "false", the dialect conversion does not does not build any
-  /// custom materializations and instead inserts
-  /// builtin.unrealized_conversion_cast ops to ensure that the resulting IR
-  /// is valid.
-  bool buildMaterializations = true;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index cc9c9495e5155c..b23fb97959ed67 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -702,12 +702,14 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
     return rewrite->getKind() == Kind::UnresolvedMaterialization;
   }
 
-  void rollback() override;
-
   UnrealizedConversionCastOp getOperation() const {
     return cast<UnrealizedConversionCastOp>(op);
   }
 
+  void rollback() override;
+
+  void cleanup(RewriterBase &rewriter) override;
+
   /// Return the type converter of this materialization (which may be null).
   const TypeConverter *getConverter() const {
     return converterAndKind.getPointer();
@@ -764,7 +766,7 @@ namespace detail {
 struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
                                          const ConversionConfig &config)
-      : context(ctx), eraseRewriter(ctx), config(config) {}
+      : context(ctx), config(config) {}
 
   //===--------------------------------------------------------------------===//
   // State Management
@@ -832,7 +834,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   //===--------------------------------------------------------------------===//
   // Materializations
   //===--------------------------------------------------------------------===//
-
   /// Build an unresolved materialization operation given an output type and set
   /// of input operands.
   Value buildUnresolvedMaterialization(MaterializationKind kind,
@@ -881,7 +882,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
 
     /// Erase the given op (unless it was already erased).
     void eraseOp(Operation *op) override {
-      if (wasErased(op))
+      if (erased.contains(op))
         return;
       op->dropAllUses();
       RewriterBase::eraseOp(op);
@@ -889,24 +890,17 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
 
     /// Erase the given block (unless it was already erased).
     void eraseBlock(Block *block) override {
-      if (wasErased(block))
+      if (erased.contains(block))
         return;
       assert(block->empty() && "expected empty block");
       block->dropAllDefinedValueUses();
       RewriterBase::eraseBlock(block);
     }
 
-    bool wasErased(void *ptr) const { return erased.contains(ptr); }
-
-    bool wasErased(OperationRewrite *rewrite) const {
-      return wasErased(rewrite->getOperation());
-    }
-
     void notifyOperationErased(Operation *op) override { erased.insert(op); }
 
     void notifyBlockErased(Block *block) override { erased.insert(block); }
 
-  private:
     /// Pointers to all erased operations and blocks.
     DenseSet<void *> erased;
   };
@@ -918,11 +912,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// MLIR context.
   MLIRContext *context;
 
-  /// A rewriter that keeps track of ops/block that were already erased and
-  /// skips duplicate op/block erasures. This rewriter is used during the
-  /// "cleanup" phase.
-  SingleEraseRewriter eraseRewriter;
-
   // Mapping between replaced values that differ in type. This happens when
   // replacing a value with one of a different type.
   ConversionValueMapping mapping;
@@ -1069,6 +1058,10 @@ void UnresolvedMaterializationRewrite::rollback() {
   op->erase();
 }
 
+void UnresolvedMaterializationRewrite::cleanup(RewriterBase &rewriter) {
+  rewriter.eraseOp(op);
+}
+
 void ConversionPatternRewriterImpl::applyRewrites() {
   // Commit all rewrites.
   IRRewriter rewriter(context, config.listener);
@@ -1076,6 +1069,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
     rewrite->commit(rewriter);
 
   // Clean up all rewrites.
+  SingleEraseRewriter eraseRewriter(context);
   for (auto &rewrite : rewrites)
     rewrite->cleanup(eraseRewriter);
 }
@@ -2359,6 +2353,12 @@ struct OperationConverter {
       ConversionPatternRewriterImpl &rewriterImpl,
       DenseMap<Value, SmallVector<Value>> &inverseMapping);
 
+  /// Legalize any unresolved type materializations.
+  LogicalResult legalizeUnresolvedMaterializations(
+      ConversionPatternRewriter &rewriter,
+      ConversionPatternRewriterImpl &rewriterImpl,
+      DenseMap<Value, SmallVector<Value>> &inverseMapping);
+
   /// Legalize an operation result that was marked as "erased".
   LogicalResult
   legalizeErasedResult(Operation *op, OpResult result,
@@ -2405,56 +2405,6 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
   return success();
 }
 
-static LogicalResult
-legalizeUnresolvedMaterialization(RewriterBase &rewriter,
-                                  UnresolvedMaterializationRewrite *rewrite) {
-  UnrealizedConversionCastOp op = rewrite->getOperation();
-  assert(!op.use_empty() &&
-         "expected that dead materializations have already been DCE'd");
-  Operation::operand_range inputOperands = op.getOperands();
-  Type outputType = op.getResultTypes()[0];
-
-  // Try to materialize the conversion.
-  if (const TypeConverter *converter = rewrite->getConverter()) {
-    rewriter.setInsertionPoint(op);
-    Value newMaterialization;
-    switch (rewrite->getMaterializationKind()) {
-    case MaterializationKind::Argument:
-      // Try to materialize an argument conversion.
-      newMaterialization = converter->materializeArgumentConversion(
-          rewriter, op->getLoc(), outputType, inputOperands);
-      if (newMaterialization)
-        break;
-      // If an argument materialization failed, fallback to trying a target
-      // materialization.
-      [[fallthrough]];
-    case MaterializationKind::Target:
-      newMaterialization = converter->materializeTargetConversion(
-          rewriter, op->getLoc(), outputType, inputOperands);
-      break;
-    case MaterializationKind::Source:
-      newMaterialization = converter->materializeSourceConversion(
-          rewriter, op->getLoc(), outputType, inputOperands);
-      break;
-    }
-    if (newMaterialization) {
-      assert(newMaterialization.getType() == outputType &&
-             "materialization callback produced value of incorrect type");
-      rewriter.replaceOp(op, newMaterialization);
-      return success();
-    }
-  }
-
-  InFlightDiagnostic diag = op->emitError()
-                            << "failed to legalize unresolved materialization "
-                               "from ("
-                            << inputOperands.getTypes() << ") to " << outputType
-                            << " that remained live after conversion";
-  diag.attachNote(op->getUsers().begin()->getLoc())
-      << "see existing live user here: " << *op->getUsers().begin();
-  return failure();
-}
-
 LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   if (ops.empty())
     return success();
@@ -2496,37 +2446,6 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   } else {
     rewriterImpl.applyRewrites();
   }
-
-  // Gather all unresolved materializations.
-  SmallVector<UnrealizedConversionCastOp> allCastOps;
-  DenseMap<Operation *, UnresolvedMaterializationRewrite *> rewriteMap;
-  for (std::unique_ptr<IRRewrite> &rewrite : rewriterImpl.rewrites) {
-    auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get());
-    if (!mat)
-      continue;
-    if (rewriterImpl.eraseRewriter.wasErased(mat))
-      continue;
-    allCastOps.push_back(mat->getOperation());
-    rewriteMap[mat->getOperation()] = mat;
-  }
-
-  // Reconcile all UnrealizedConversionCastOps that were inserted by the
-  // dialect conversion frameworks. (Not the one that were inserted by
-  // patterns.)
-  SmallVector<UnrealizedConversionCastOp> remainingCastOps;
-  reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
-
-  // Try to legalize all unresolved materializations.
-  if (config.buildMaterializations) {
-    IRRewriter rewriter(rewriterImpl.context, config.listener);
-    for (UnrealizedConversionCastOp castOp : remainingCastOps) {
-      auto it = rewriteMap.find(castOp.getOperation());
-      assert(it != rewriteMap.end() && "inconsistent state");
-      if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
-        return failure();
-    }
-  }
-
   return success();
 }
 
@@ -2540,6 +2459,9 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
   if (failed(legalizeConvertedOpResultTypes(rewriter, rewriterImpl,
                                             inverseMapping)))
     return failure();
+  if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
+                                                inverseMapping)))
+    return failure();
   return success();
 }
 
@@ -2655,6 +2577,279 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
   return success();
 }
 
+/// Replace the results of a materialization operation with the given values.
+static void
+replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl,
+                       ResultRange matResults, ValueRange values,
+                       DenseMap<Value, SmallVector<Value>> &inverseMapping) {
+  matResults.replaceAllUsesWith(values);
+
+  // For each of the materialization results, update the inverse mappings to
+  // point to the replacement values.
+  for (auto [matResult, newValue] : llvm::zip(matResults, values)) {
+    auto inverseMapIt = inverseMapping.find(matResult);
+    if (inverseMapIt == inverseMapping.end())
+      continue;
+
+    // Update the reverse mapping, or remove the mapping if we couldn't update
+    // it. Not being able to update signals that the mapping would have become
+    // circular (i.e. %foo -> newValue -> %foo), which may occur as values are
+    // propagated through temporary materializations. We simply drop the
+    // mapping, and let the post-conversion replacement logic handle updating
+    // uses.
+    for (Value inverseMapVal : inverseMapIt->second)
+      if (!rewriterImpl.mapping.tryMap(inverseMapVal, newValue))
+        rewriterImpl.mapping.erase(inverseMapVal);
+  }
+}
+
+/// Compute all of the unresolved materializations that will persist beyond the
+/// conversion process, and require inserting a proper user materialization for.
+static void computeNecessaryMaterializations(
+    DenseMap<Operation *, UnresolvedMaterializationRewrite *>
+        &materializationOps,
+    ConversionPatternRewriter &rewriter,
+    ConversionPatternRewriterImpl &rewriterImpl,
+    DenseMap<Value, SmallVector<Value>> &inverseMapping,
+    SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) {
+  // Helper function to check if the given value or a not yet materialized
+  // replacement of the given value is live.
+  // Note: `inverseMapping` maps from replaced values to original values.
+  auto isLive = [&](Value value) {
+    auto findFn = [&](Operation *user) {
+      auto matIt = materializationOps.find(user);
+      if (matIt != materializationOps.end())
+        return !necessaryMaterializations.count(matIt->second);
+      return rewriterImpl.isOpIgnored(user);
+    };
+    // A worklist is needed because a value may have gone through a chain of
+    // replacements and each of the replaced values may have live users.
+    SmallVector<Value> worklist;
+    worklist.push_back(value);
+    while (!worklist.empty()) {
+      Value next = worklist.pop_back_val();
+      if (llvm::find_if_not(next.getUsers(), findFn) != next.user_end())
+        return true;
+      // This value may be replacing another value that has a live user.
+      llvm::append_range(worklist, inverseMapping.lookup(next));
+    }
+    return false;
+  };
+
+  llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue =
+      [&](Value invalidRoot, Value value, Type type) {
+        // Check to see if the input operation was remapped to a variant of the
+        // output.
+        Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
+        if (remappedValue.getType() == type && remappedValue != invalidRoot)
+          return remappedValue;
+
+        // Check to see if the input is a materialization operation that
+        // provides an inverse conversion. We just check blindly for
+        // UnrealizedConversionCastOp here, but it has no effect on correctness.
+        auto inputCastOp = value.getDefiningOp<UnrealizedConversionCastOp>();
+        if (inputCastOp && inputCastOp->getNumOperands() == 1)
+          return lookupRemappedValue(invalidRoot, inputCastOp->getOperand(0),
+                                     type);
+
+        return Value();
+      };
+
+  SetVector<UnresolvedMaterializationRewrite *> worklist;
+  for (auto &rewrite : rewriterImpl.rewrites) {
+    auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get());
+    if (!mat)
+      continue;
+    materializationOps.try_emplace(mat->getOperation(), mat);
+    worklist.insert(mat);
+  }
+  while (!worklist.empty()) {
+    UnresolvedMaterializationRewrite *mat = worklist.pop_back_val();
+    UnrealizedConversionCastOp op = mat->getOperation();
+
+    // We currently only handle target materializations here.
+    assert(op->getNumResults() == 1 && "unexpected materialization type");
+    OpResult opResult = op->getOpResult(0);
+    Type outputType = opResult.getType();
+    Operation::operand_range inputOperands = op.getOperands();
+
+    // Try to forward propagate operands for user conversion casts that result
+    // in the input types of the current cast.
+    for (Operation *user : llvm::make_early_inc_range(opResult.getUsers())) {
+      auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
+      if (!castOp)
+        continue;
+      if (castOp->getResultTypes() == inputOperands.getTypes()) {
+        replaceMaterialization(rewriterImpl, user->getResults(), inputOperands,
+                               inverseMapping);
+        necessaryMaterializations.remove(materializationOps.lookup(user));
+      }
+    }
+
+    // Try to avoid materializing a resolved materialization if possible.
+    // Handle the case of a 1-1 materialization.
+    if (inputOperands.size() == 1) {
+      // Check to see if the input operation was remapped to a variant of the
+      // output.
+      Value remappedValue =
+          lookupRemappedValue(opResult, inputOperands[0], outputType);
+      if (remappedValue && remappedValue != opResult) {
+        replaceMaterialization(rewriterImpl, opResult, remappedValue,
+                               inverseMapping);
+        necessaryMaterializations.remove(mat);
+        continue;
+      }
+    } else {
+      // TODO: Avoid materializing other types of conversions here.
+    }
+
+    // If the materialization does not have any live users, we don't need to
+    // generate a user materialization for it.
+    bool isMaterializationLive = isLive(opResult);
+    if (!isMaterializationLive)
+      continue;
+    if (!necessaryMaterializations.insert(mat))
+      continue;
+
+    // Reprocess input materializations to see if they have an updated status.
+    for (Value input : inputOperands) {
+      if (auto parentOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
+        if (auto *mat = materializationOps.lookup(parentOp))
+          worklist.insert(mat);
+      }
+    }
+  }
+}
+
+/// Legalize the given unresolved materialization. Returns success if the
+/// materialization was legalized, failure otherise.
+static LogicalResult legalizeUnresolvedMaterialization(
+    UnresolvedMaterializationRewrite &mat,
+    DenseMap<Operation *, UnresolvedMaterializationRewrite *>
+        &materializationOps,
+    ConversionPatternRewriter &rewriter,
+    ConversionPatternRewriterImpl &rewriterImpl,
+    DenseMap<Value, SmallVector<Value>> &inverseMapping) {
+  auto findLiveUser = [&](auto &&users) {
+    auto liveUserIt = llvm::find_if_not(
+        users, [&](Operation *user) { return rewriterImpl.isOpIgnored(user); });
+    return liveUserIt == users.end() ? nullptr : *liveUserIt;
+  };
+
+  llvm::unique_function<Value(Value, Type)> lookupRemappedValue =
+      [&](Value value, Type type) {
+        // Check to see if the input operation was remapped to a variant of the
+        // output.
+        Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
+        if (remappedValue.getType() == type)
+          return remappedValue;
+        return Value();
+      };
+
+  UnrealizedConversionCastOp op = mat.getOperation();
+  if (!rewriterImpl.ignoredOps.insert(op))
+    return success();
+
+  // We currently only handle target materializations here.
+  OpResult opResult = op->getOpResult(0);
+  Operation::operand_range inputOperands = op.getOperands();
+  Type outputType = opResult.getType();
+
+  // If any input to this materialization is another materialization, resolve
+  // the input first.
+  for (Value value : op->getOperands()) {
+    auto valueCast = value.getDefiningOp<UnrealizedConversionCastOp>();
+    if (!valueCast)
+      continue;
+
+    auto matIt = materializationOps.find(valueCast);
+    if (matIt != materializationOps.end())
+      if (failed(legalizeUnresolvedMaterialization(
+              *matIt->second, materializationOps, rewriter, rewriterImpl,
+              inverseMapping)))
+        return failure();
+  }
+
+  // Perform a last ditch attempt to avoid materializing a resolved
+  // materialization if possible.
+  // Handle the case of a 1-1 materialization.
+  if (inputOperands.size() == 1) {
+    // Check to see if the input operation was remapped to a variant of the
+    // output.
+    Value remappedValue = lookupRemappedValue(inputOperands[0], outputType);
+    if (remappedValue && remappedValue != opResult) {
+      replaceMaterialization(rewriterImpl, opResult, remappedValue,
+                             inverseMapping);
+      return success();
+    }
+  } else {
+    // TODO: Avoid materializing other types of conversions here.
+  }
+
+  // Try to materialize the conversion.
+  if (const TypeConverter *converter = mat.getConverter()) {
+    rewriter.setInsertionPoint(op);
+    Value newMaterialization;
+    switch (mat.getMaterializationKind()) {
+    case MaterializationKind::Argument:
+      // Try to materialize an argument conversion.
+      newMaterialization = converter->materializeArgumentConversion(
+          rewriter, op->getLoc(), outputType, inputOperands);
+      if (newMaterialization)
+        break;
+      // If an argument materialization failed, fallback to trying a target
+      // materialization.
+      [[fallthrough]];
+    case MaterializationKind::Target:
+      newMaterialization = converter->materializeTargetConversion(
+          rewriter, op->getLoc(), outputType, inputOperands);
+      break;
+    case MaterializationKind::Source:
+      newMaterialization = converter->materializeSourceConversion(
+          rewriter, op->getLoc(), outputType, inputOperands);
+      break;
+    }
+    if (newMaterialization) {
+      assert(newMaterialization.getType() == outputType &&
+             "materialization callback produced value of incorrect type");
+      replaceMaterial...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Aug 30, 2024

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

Reverts llvm/llvm-project#104668

This commit triggers an edge case that can cause circular unrealized_conversion_cast ops. #106760 may fix it, but it is has other issues. Reverting this PR for now, until I find a solution for that problem.


Patch is 31.92 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/106778.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (-11)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+294-99)
  • (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+3-2)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir (-1)
  • (modified) mlir/test/Transforms/test-legalize-type-conversion.mlir (+1-5)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 5f680e8eca7559..60113bdef16a23 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1124,17 +1124,6 @@ struct ConversionConfig {
   // already been modified) and iterators into past IR state cannot be
   // represented at the moment.
   RewriterBase::Listener *listener = nullptr;
-
-  /// If set to "true", the dialect conversion attempts to build source/target/
-  /// argument materializations through the type converter API in lieu of
-  /// builtin.unrealized_conversion_cast ops. The conversion process fails if
-  /// at least one materialization could not be built.
-  ///
-  /// If set to "false", the dialect conversion does not does not build any
-  /// custom materializations and instead inserts
-  /// builtin.unrealized_conversion_cast ops to ensure that the resulting IR
-  /// is valid.
-  bool buildMaterializations = true;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index cc9c9495e5155c..b23fb97959ed67 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -702,12 +702,14 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
     return rewrite->getKind() == Kind::UnresolvedMaterialization;
   }
 
-  void rollback() override;
-
   UnrealizedConversionCastOp getOperation() const {
     return cast<UnrealizedConversionCastOp>(op);
   }
 
+  void rollback() override;
+
+  void cleanup(RewriterBase &rewriter) override;
+
   /// Return the type converter of this materialization (which may be null).
   const TypeConverter *getConverter() const {
     return converterAndKind.getPointer();
@@ -764,7 +766,7 @@ namespace detail {
 struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
                                          const ConversionConfig &config)
-      : context(ctx), eraseRewriter(ctx), config(config) {}
+      : context(ctx), config(config) {}
 
   //===--------------------------------------------------------------------===//
   // State Management
@@ -832,7 +834,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   //===--------------------------------------------------------------------===//
   // Materializations
   //===--------------------------------------------------------------------===//
-
   /// Build an unresolved materialization operation given an output type and set
   /// of input operands.
   Value buildUnresolvedMaterialization(MaterializationKind kind,
@@ -881,7 +882,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
 
     /// Erase the given op (unless it was already erased).
     void eraseOp(Operation *op) override {
-      if (wasErased(op))
+      if (erased.contains(op))
         return;
       op->dropAllUses();
       RewriterBase::eraseOp(op);
@@ -889,24 +890,17 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
 
     /// Erase the given block (unless it was already erased).
     void eraseBlock(Block *block) override {
-      if (wasErased(block))
+      if (erased.contains(block))
         return;
       assert(block->empty() && "expected empty block");
       block->dropAllDefinedValueUses();
       RewriterBase::eraseBlock(block);
     }
 
-    bool wasErased(void *ptr) const { return erased.contains(ptr); }
-
-    bool wasErased(OperationRewrite *rewrite) const {
-      return wasErased(rewrite->getOperation());
-    }
-
     void notifyOperationErased(Operation *op) override { erased.insert(op); }
 
     void notifyBlockErased(Block *block) override { erased.insert(block); }
 
-  private:
     /// Pointers to all erased operations and blocks.
     DenseSet<void *> erased;
   };
@@ -918,11 +912,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// MLIR context.
   MLIRContext *context;
 
-  /// A rewriter that keeps track of ops/block that were already erased and
-  /// skips duplicate op/block erasures. This rewriter is used during the
-  /// "cleanup" phase.
-  SingleEraseRewriter eraseRewriter;
-
   // Mapping between replaced values that differ in type. This happens when
   // replacing a value with one of a different type.
   ConversionValueMapping mapping;
@@ -1069,6 +1058,10 @@ void UnresolvedMaterializationRewrite::rollback() {
   op->erase();
 }
 
+void UnresolvedMaterializationRewrite::cleanup(RewriterBase &rewriter) {
+  rewriter.eraseOp(op);
+}
+
 void ConversionPatternRewriterImpl::applyRewrites() {
   // Commit all rewrites.
   IRRewriter rewriter(context, config.listener);
@@ -1076,6 +1069,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
     rewrite->commit(rewriter);
 
   // Clean up all rewrites.
+  SingleEraseRewriter eraseRewriter(context);
   for (auto &rewrite : rewrites)
     rewrite->cleanup(eraseRewriter);
 }
@@ -2359,6 +2353,12 @@ struct OperationConverter {
       ConversionPatternRewriterImpl &rewriterImpl,
       DenseMap<Value, SmallVector<Value>> &inverseMapping);
 
+  /// Legalize any unresolved type materializations.
+  LogicalResult legalizeUnresolvedMaterializations(
+      ConversionPatternRewriter &rewriter,
+      ConversionPatternRewriterImpl &rewriterImpl,
+      DenseMap<Value, SmallVector<Value>> &inverseMapping);
+
   /// Legalize an operation result that was marked as "erased".
   LogicalResult
   legalizeErasedResult(Operation *op, OpResult result,
@@ -2405,56 +2405,6 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
   return success();
 }
 
-static LogicalResult
-legalizeUnresolvedMaterialization(RewriterBase &rewriter,
-                                  UnresolvedMaterializationRewrite *rewrite) {
-  UnrealizedConversionCastOp op = rewrite->getOperation();
-  assert(!op.use_empty() &&
-         "expected that dead materializations have already been DCE'd");
-  Operation::operand_range inputOperands = op.getOperands();
-  Type outputType = op.getResultTypes()[0];
-
-  // Try to materialize the conversion.
-  if (const TypeConverter *converter = rewrite->getConverter()) {
-    rewriter.setInsertionPoint(op);
-    Value newMaterialization;
-    switch (rewrite->getMaterializationKind()) {
-    case MaterializationKind::Argument:
-      // Try to materialize an argument conversion.
-      newMaterialization = converter->materializeArgumentConversion(
-          rewriter, op->getLoc(), outputType, inputOperands);
-      if (newMaterialization)
-        break;
-      // If an argument materialization failed, fallback to trying a target
-      // materialization.
-      [[fallthrough]];
-    case MaterializationKind::Target:
-      newMaterialization = converter->materializeTargetConversion(
-          rewriter, op->getLoc(), outputType, inputOperands);
-      break;
-    case MaterializationKind::Source:
-      newMaterialization = converter->materializeSourceConversion(
-          rewriter, op->getLoc(), outputType, inputOperands);
-      break;
-    }
-    if (newMaterialization) {
-      assert(newMaterialization.getType() == outputType &&
-             "materialization callback produced value of incorrect type");
-      rewriter.replaceOp(op, newMaterialization);
-      return success();
-    }
-  }
-
-  InFlightDiagnostic diag = op->emitError()
-                            << "failed to legalize unresolved materialization "
-                               "from ("
-                            << inputOperands.getTypes() << ") to " << outputType
-                            << " that remained live after conversion";
-  diag.attachNote(op->getUsers().begin()->getLoc())
-      << "see existing live user here: " << *op->getUsers().begin();
-  return failure();
-}
-
 LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   if (ops.empty())
     return success();
@@ -2496,37 +2446,6 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   } else {
     rewriterImpl.applyRewrites();
   }
-
-  // Gather all unresolved materializations.
-  SmallVector<UnrealizedConversionCastOp> allCastOps;
-  DenseMap<Operation *, UnresolvedMaterializationRewrite *> rewriteMap;
-  for (std::unique_ptr<IRRewrite> &rewrite : rewriterImpl.rewrites) {
-    auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get());
-    if (!mat)
-      continue;
-    if (rewriterImpl.eraseRewriter.wasErased(mat))
-      continue;
-    allCastOps.push_back(mat->getOperation());
-    rewriteMap[mat->getOperation()] = mat;
-  }
-
-  // Reconcile all UnrealizedConversionCastOps that were inserted by the
-  // dialect conversion frameworks. (Not the one that were inserted by
-  // patterns.)
-  SmallVector<UnrealizedConversionCastOp> remainingCastOps;
-  reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
-
-  // Try to legalize all unresolved materializations.
-  if (config.buildMaterializations) {
-    IRRewriter rewriter(rewriterImpl.context, config.listener);
-    for (UnrealizedConversionCastOp castOp : remainingCastOps) {
-      auto it = rewriteMap.find(castOp.getOperation());
-      assert(it != rewriteMap.end() && "inconsistent state");
-      if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
-        return failure();
-    }
-  }
-
   return success();
 }
 
@@ -2540,6 +2459,9 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
   if (failed(legalizeConvertedOpResultTypes(rewriter, rewriterImpl,
                                             inverseMapping)))
     return failure();
+  if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
+                                                inverseMapping)))
+    return failure();
   return success();
 }
 
@@ -2655,6 +2577,279 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
   return success();
 }
 
+/// Replace the results of a materialization operation with the given values.
+static void
+replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl,
+                       ResultRange matResults, ValueRange values,
+                       DenseMap<Value, SmallVector<Value>> &inverseMapping) {
+  matResults.replaceAllUsesWith(values);
+
+  // For each of the materialization results, update the inverse mappings to
+  // point to the replacement values.
+  for (auto [matResult, newValue] : llvm::zip(matResults, values)) {
+    auto inverseMapIt = inverseMapping.find(matResult);
+    if (inverseMapIt == inverseMapping.end())
+      continue;
+
+    // Update the reverse mapping, or remove the mapping if we couldn't update
+    // it. Not being able to update signals that the mapping would have become
+    // circular (i.e. %foo -> newValue -> %foo), which may occur as values are
+    // propagated through temporary materializations. We simply drop the
+    // mapping, and let the post-conversion replacement logic handle updating
+    // uses.
+    for (Value inverseMapVal : inverseMapIt->second)
+      if (!rewriterImpl.mapping.tryMap(inverseMapVal, newValue))
+        rewriterImpl.mapping.erase(inverseMapVal);
+  }
+}
+
+/// Compute all of the unresolved materializations that will persist beyond the
+/// conversion process, and require inserting a proper user materialization for.
+static void computeNecessaryMaterializations(
+    DenseMap<Operation *, UnresolvedMaterializationRewrite *>
+        &materializationOps,
+    ConversionPatternRewriter &rewriter,
+    ConversionPatternRewriterImpl &rewriterImpl,
+    DenseMap<Value, SmallVector<Value>> &inverseMapping,
+    SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) {
+  // Helper function to check if the given value or a not yet materialized
+  // replacement of the given value is live.
+  // Note: `inverseMapping` maps from replaced values to original values.
+  auto isLive = [&](Value value) {
+    auto findFn = [&](Operation *user) {
+      auto matIt = materializationOps.find(user);
+      if (matIt != materializationOps.end())
+        return !necessaryMaterializations.count(matIt->second);
+      return rewriterImpl.isOpIgnored(user);
+    };
+    // A worklist is needed because a value may have gone through a chain of
+    // replacements and each of the replaced values may have live users.
+    SmallVector<Value> worklist;
+    worklist.push_back(value);
+    while (!worklist.empty()) {
+      Value next = worklist.pop_back_val();
+      if (llvm::find_if_not(next.getUsers(), findFn) != next.user_end())
+        return true;
+      // This value may be replacing another value that has a live user.
+      llvm::append_range(worklist, inverseMapping.lookup(next));
+    }
+    return false;
+  };
+
+  llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue =
+      [&](Value invalidRoot, Value value, Type type) {
+        // Check to see if the input operation was remapped to a variant of the
+        // output.
+        Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
+        if (remappedValue.getType() == type && remappedValue != invalidRoot)
+          return remappedValue;
+
+        // Check to see if the input is a materialization operation that
+        // provides an inverse conversion. We just check blindly for
+        // UnrealizedConversionCastOp here, but it has no effect on correctness.
+        auto inputCastOp = value.getDefiningOp<UnrealizedConversionCastOp>();
+        if (inputCastOp && inputCastOp->getNumOperands() == 1)
+          return lookupRemappedValue(invalidRoot, inputCastOp->getOperand(0),
+                                     type);
+
+        return Value();
+      };
+
+  SetVector<UnresolvedMaterializationRewrite *> worklist;
+  for (auto &rewrite : rewriterImpl.rewrites) {
+    auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get());
+    if (!mat)
+      continue;
+    materializationOps.try_emplace(mat->getOperation(), mat);
+    worklist.insert(mat);
+  }
+  while (!worklist.empty()) {
+    UnresolvedMaterializationRewrite *mat = worklist.pop_back_val();
+    UnrealizedConversionCastOp op = mat->getOperation();
+
+    // We currently only handle target materializations here.
+    assert(op->getNumResults() == 1 && "unexpected materialization type");
+    OpResult opResult = op->getOpResult(0);
+    Type outputType = opResult.getType();
+    Operation::operand_range inputOperands = op.getOperands();
+
+    // Try to forward propagate operands for user conversion casts that result
+    // in the input types of the current cast.
+    for (Operation *user : llvm::make_early_inc_range(opResult.getUsers())) {
+      auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
+      if (!castOp)
+        continue;
+      if (castOp->getResultTypes() == inputOperands.getTypes()) {
+        replaceMaterialization(rewriterImpl, user->getResults(), inputOperands,
+                               inverseMapping);
+        necessaryMaterializations.remove(materializationOps.lookup(user));
+      }
+    }
+
+    // Try to avoid materializing a resolved materialization if possible.
+    // Handle the case of a 1-1 materialization.
+    if (inputOperands.size() == 1) {
+      // Check to see if the input operation was remapped to a variant of the
+      // output.
+      Value remappedValue =
+          lookupRemappedValue(opResult, inputOperands[0], outputType);
+      if (remappedValue && remappedValue != opResult) {
+        replaceMaterialization(rewriterImpl, opResult, remappedValue,
+                               inverseMapping);
+        necessaryMaterializations.remove(mat);
+        continue;
+      }
+    } else {
+      // TODO: Avoid materializing other types of conversions here.
+    }
+
+    // If the materialization does not have any live users, we don't need to
+    // generate a user materialization for it.
+    bool isMaterializationLive = isLive(opResult);
+    if (!isMaterializationLive)
+      continue;
+    if (!necessaryMaterializations.insert(mat))
+      continue;
+
+    // Reprocess input materializations to see if they have an updated status.
+    for (Value input : inputOperands) {
+      if (auto parentOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
+        if (auto *mat = materializationOps.lookup(parentOp))
+          worklist.insert(mat);
+      }
+    }
+  }
+}
+
+/// Legalize the given unresolved materialization. Returns success if the
+/// materialization was legalized, failure otherise.
+static LogicalResult legalizeUnresolvedMaterialization(
+    UnresolvedMaterializationRewrite &mat,
+    DenseMap<Operation *, UnresolvedMaterializationRewrite *>
+        &materializationOps,
+    ConversionPatternRewriter &rewriter,
+    ConversionPatternRewriterImpl &rewriterImpl,
+    DenseMap<Value, SmallVector<Value>> &inverseMapping) {
+  auto findLiveUser = [&](auto &&users) {
+    auto liveUserIt = llvm::find_if_not(
+        users, [&](Operation *user) { return rewriterImpl.isOpIgnored(user); });
+    return liveUserIt == users.end() ? nullptr : *liveUserIt;
+  };
+
+  llvm::unique_function<Value(Value, Type)> lookupRemappedValue =
+      [&](Value value, Type type) {
+        // Check to see if the input operation was remapped to a variant of the
+        // output.
+        Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
+        if (remappedValue.getType() == type)
+          return remappedValue;
+        return Value();
+      };
+
+  UnrealizedConversionCastOp op = mat.getOperation();
+  if (!rewriterImpl.ignoredOps.insert(op))
+    return success();
+
+  // We currently only handle target materializations here.
+  OpResult opResult = op->getOpResult(0);
+  Operation::operand_range inputOperands = op.getOperands();
+  Type outputType = opResult.getType();
+
+  // If any input to this materialization is another materialization, resolve
+  // the input first.
+  for (Value value : op->getOperands()) {
+    auto valueCast = value.getDefiningOp<UnrealizedConversionCastOp>();
+    if (!valueCast)
+      continue;
+
+    auto matIt = materializationOps.find(valueCast);
+    if (matIt != materializationOps.end())
+      if (failed(legalizeUnresolvedMaterialization(
+              *matIt->second, materializationOps, rewriter, rewriterImpl,
+              inverseMapping)))
+        return failure();
+  }
+
+  // Perform a last ditch attempt to avoid materializing a resolved
+  // materialization if possible.
+  // Handle the case of a 1-1 materialization.
+  if (inputOperands.size() == 1) {
+    // Check to see if the input operation was remapped to a variant of the
+    // output.
+    Value remappedValue = lookupRemappedValue(inputOperands[0], outputType);
+    if (remappedValue && remappedValue != opResult) {
+      replaceMaterialization(rewriterImpl, opResult, remappedValue,
+                             inverseMapping);
+      return success();
+    }
+  } else {
+    // TODO: Avoid materializing other types of conversions here.
+  }
+
+  // Try to materialize the conversion.
+  if (const TypeConverter *converter = mat.getConverter()) {
+    rewriter.setInsertionPoint(op);
+    Value newMaterialization;
+    switch (mat.getMaterializationKind()) {
+    case MaterializationKind::Argument:
+      // Try to materialize an argument conversion.
+      newMaterialization = converter->materializeArgumentConversion(
+          rewriter, op->getLoc(), outputType, inputOperands);
+      if (newMaterialization)
+        break;
+      // If an argument materialization failed, fallback to trying a target
+      // materialization.
+      [[fallthrough]];
+    case MaterializationKind::Target:
+      newMaterialization = converter->materializeTargetConversion(
+          rewriter, op->getLoc(), outputType, inputOperands);
+      break;
+    case MaterializationKind::Source:
+      newMaterialization = converter->materializeSourceConversion(
+          rewriter, op->getLoc(), outputType, inputOperands);
+      break;
+    }
+    if (newMaterialization) {
+      assert(newMaterialization.getType() == outputType &&
+             "materialization callback produced value of incorrect type");
+      replaceMaterial...
[truncated]

@matthias-springer matthias-springer merged commit 5eda498 into main Aug 30, 2024
11 of 12 checks passed
@matthias-springer matthias-springer deleted the revert-104668-users/matthias-springer/optional_materializations branch August 30, 2024 19:34
@maerhart maerhart mentioned this pull request Sep 3, 2024
@maerhart
Copy link
Member

maerhart commented Sep 3, 2024

@matthias-springer Do you have a timeline for the mentioned solution? If it takes longer (>1 week), would it be possible to revert #101514 as well (because of the issue @mikeurbach already mentioned there)?

@matthias-springer
Copy link
Member Author

Oh I forgot that this fixed another issue in your code base. There was something that I wanted to try: #104668 (comment). Let me implement that and see if it works, then I can land the PR again. It should be a pretty small change to the original PR.

@matthias-springer
Copy link
Member Author

I just went back to the discussion on #101514. It sounded like #104668 would fix your issue. (That PR has landed.) Is that not the case?

@matthias-springer
Copy link
Member Author

I think this should be fine: #107109. (A test case is still missing.) This passes with MLIR. Also tried it with IREE with a few minor integration fixes and I don't see any failure (but I have not run the CI yet). Ideally, I'd like to also run it by Google to see if something is breaking.

Is there an easy way to reproduce your issue? Maybe we can just fix it quickly to give people more time to take a look at #107109.

@maerhart
Copy link
Member

maerhart commented Sep 4, 2024

Thanks a lot for working on this!

I just went back to the discussion on #101514. It sounded like #104668 would fix your issue. (That PR has landed.) Is that not the case?

Yes, #104668 fixed the issue. The problem is just that this PR reverted that fix.

I checked out #107109, and all our tests pass on top of that branch. It's not very time-critical; we integrate upstream weekly, but delaying it by a few days should be fine.

I don't have an upstream example to reproduce this, but we can return to it if necessary.

As an example, the following fails with a "null operands" error (I think for an op created during type conversion materialization, maybe because the %1 operand still had the source type when the add was converted and later gets replaced with the result of the converted hwarith.constant op, but in some way not correctly?)

  hw.module @backedges() {
    // graph region
    %0 = hwarith.add %1, %1 : (ui1, ui1) -> ui2
    %1 = hwarith.constant 1 : ui1
    hw.output
  }

If the order of the two operations in the body is changed, it succeeds and returns the following:

  hw.module @backedges() {
    %true = hw.constant true
    %false = hw.constant false
    %0 = comb.concat %false, %true : i1, i1
    %false_0 = hw.constant false
    %1 = comb.concat %false_0, %true : i1, i1
    %2 = comb.add %0, %1 : i2
    hw.output
  }

command: circt-opt -lower-hwarith-to-hw input.mlir

@matthias-springer
Copy link
Member Author

Oh I got confused with the PR numbers... I heard back from the Google and IREE folks and #107109 is safe to land from their side. I will try to add a test for the cyclic cast ops (this evening or tomorrow), and then this is good to merge from my side.

@matthias-springer
Copy link
Member Author

matthias-springer commented Sep 5, 2024

I added a test case. #107109 is ready to land from my perspective. I added you as reviewers. (The original PR was already reviewed, but I added a bit of code, so it's better to have it reviewed again. The part that's new is eraseDeadUnrealizedCasts and the test case.)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:bufferization Bufferization infrastructure mlir:core MLIR Core Infrastructure mlir:gpu mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants