Skip to content

Commit

Permalink
[mlir][Transforms] Dialect conversion: Build unresolved materializati…
Browse files Browse the repository at this point in the history
…on for replaced ops

When inserting an argument/source/target materialization, the dialect conversion framework first inserts a "dummy" `unrealized_conversion_cast` op (during the rewrite process) and then (in the "finialize" phase) replaces these cast ops with the IR generated by the type converter callback.

This is the case for all materializations, except when ops are being replaced with values that have a different type. In that case, the dialect conversion currently directly emits a source materialization. This commit changes the implementation, such that a temporary `unrealized_conversion_cast` is also inserted in this case.

This commit simplifies the code base: all materializations now happen in `legalizeUnresolvedMaterialization`. This commit makes it possible to decouple source/target/argument materializations from the dialect conversion (to reduce the complexity of the code base). Such materializations can then also be optional. This will be implemented in a follow-up commit.
  • Loading branch information
matthias-springer committed Aug 3, 2024
1 parent 82afd9d commit 2d9d128
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 87 deletions.
126 changes: 51 additions & 75 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2348,6 +2348,12 @@ struct OperationConverter {
legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &rewriterImpl);

/// Legalize the types of converted op results.
LogicalResult legalizeConvertedOpResultTypes(
ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &rewriterImpl,
DenseMap<Value, SmallVector<Value>> &inverseMapping);

/// Legalize any unresolved type materializations.
LogicalResult legalizeUnresolvedMaterializations(
ConversionPatternRewriter &rewriter,
Expand All @@ -2359,14 +2365,6 @@ struct OperationConverter {
legalizeErasedResult(Operation *op, OpResult result,
ConversionPatternRewriterImpl &rewriterImpl);

/// Legalize an operation result that was replaced with a value of a different
/// type.
LogicalResult legalizeChangedResultType(
Operation *op, OpResult result, Value newValue,
const TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &rewriterImpl,
const DenseMap<Value, SmallVector<Value>> &inverseMapping);

/// Dialect conversion configuration.
ConversionConfig config;

Expand Down Expand Up @@ -2459,10 +2457,42 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
return failure();
DenseMap<Value, SmallVector<Value>> inverseMapping =
rewriterImpl.mapping.getInverse();
if (failed(legalizeConvertedOpResultTypes(rewriter, rewriterImpl,
inverseMapping)))
return failure();
if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
inverseMapping)))
return failure();
return success();
}

/// Finds a user of the given value, or of any other value that the given value
/// replaced, that was not replaced in the conversion process.
static Operation *findLiveUserOfReplaced(
Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
SmallVector<Value> worklist(1, initialValue);
while (!worklist.empty()) {
Value value = worklist.pop_back_val();

// Walk the users of this value to see if there are any live users that
// weren't replaced during conversion.
auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) {
return rewriterImpl.isOpIgnored(user);
});
if (liveUserIt != value.user_end())
return *liveUserIt;
auto mapIt = inverseMapping.find(value);
if (mapIt != inverseMapping.end())
worklist.append(mapIt->second);
}
return nullptr;
}

LogicalResult OperationConverter::legalizeConvertedOpResultTypes(
ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &rewriterImpl,
DenseMap<Value, SmallVector<Value>> &inverseMapping) {
// Process requested operation replacements.
for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) {
auto *opReplacement =
Expand All @@ -2485,14 +2515,21 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
if (result.getType() == newValue.getType())
continue;

Operation *liveUser =
findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
if (!liveUser)
continue;

// Legalize this result.
rewriter.setInsertionPoint(op);
if (failed(legalizeChangedResultType(
op, result, newValue, opReplacement->getConverter(), rewriter,
rewriterImpl, inverseMapping)))
return failure();
Value castValue = rewriterImpl.buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(result), op->getLoc(),
/*inputs=*/newValue, /*outputType=*/result.getType(),
opReplacement->getConverter());
rewriterImpl.mapping.map(result, castValue);
inverseMapping[castValue].push_back(result);
}
}

return success();
}

Expand All @@ -2502,7 +2539,7 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
// Functor used to check if all users of a value will be dead after
// conversion.
// TODO: This should probably query the inverse mapping, same as in
// `legalizeChangedResultType`.
// `legalizeConvertedOpResultTypes`.
auto findLiveUser = [&](Value val) {
auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
return rewriterImpl.isOpIgnored(user);
Expand Down Expand Up @@ -2832,67 +2869,6 @@ LogicalResult OperationConverter::legalizeErasedResult(
return success();
}

/// Finds a user of the given value, or of any other value that the given value
/// replaced, that was not replaced in the conversion process.
static Operation *findLiveUserOfReplaced(
Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
SmallVector<Value> worklist(1, initialValue);
while (!worklist.empty()) {
Value value = worklist.pop_back_val();

// Walk the users of this value to see if there are any live users that
// weren't replaced during conversion.
auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) {
return rewriterImpl.isOpIgnored(user);
});
if (liveUserIt != value.user_end())
return *liveUserIt;
auto mapIt = inverseMapping.find(value);
if (mapIt != inverseMapping.end())
worklist.append(mapIt->second);
}
return nullptr;
}

LogicalResult OperationConverter::legalizeChangedResultType(
Operation *op, OpResult result, Value newValue,
const TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &rewriterImpl,
const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
Operation *liveUser =
findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
if (!liveUser)
return success();

// Functor used to emit a conversion error for a failed materialization.
auto emitConversionError = [&] {
InFlightDiagnostic diag = op->emitError()
<< "failed to materialize conversion for result #"
<< result.getResultNumber() << " of operation '"
<< op->getName()
<< "' that remained live after conversion";
diag.attachNote(liveUser->getLoc())
<< "see existing live user here: " << *liveUser;
return failure();
};

// If the replacement has a type converter, attempt to materialize a
// conversion back to the original type.
if (!replConverter)
return emitConversionError();

// Materialize a conversion for this live result value.
Type resultType = result.getType();
Value convertedValue = replConverter->materializeSourceConversion(
rewriter, op->getLoc(), resultType, newValue);
if (!convertedValue)
return emitConversionError();

rewriterImpl.mapping.map(result, convertedValue);
return success();
}

//===----------------------------------------------------------------------===//
// Type Conversion
//===----------------------------------------------------------------------===//
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -560,8 +560,8 @@ func.func @deinterleave(%a: vector<4xf32>) -> (vector<2xf32>, vector<2xf32>) {
// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>)
// CHECK: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
// CHECK: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
// CHECK: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT0]] : f32 to vector<1xf32>
// CHECK: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT1]] : f32 to vector<1xf32>
// CHECK-DAG: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT0]] : f32 to vector<1xf32>
// CHECK-DAG: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT1]] : f32 to vector<1xf32>
// CHECK: return %[[CAST0]], %[[CAST1]]
func.func @deinterleave_scalar(%a: vector<2xf32>) -> (vector<1xf32>, vector<1xf32>) {
%0, %1 = vector.deinterleave %a: vector<2xf32> -> vector<1xf32>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,8 @@ func.func @static_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset
// memref.cast.
func.func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, strided<[1], offset: ?>> {
%0 = bufferization.to_tensor %m : memref<?xf32>
// expected-error @+1 {{failed to materialize conversion for result #0 of operation 'bufferization.to_memref' that remained live after conversion}}
// expected-error @+1 {{failed to legalize unresolved materialization from ('memref<?xf32>') to 'memref<?xf32, strided<[1], offset: ?>>' that remained live after conversion}}
%1 = bufferization.to_memref %0 : memref<?xf32, strided<[1], offset: ?>>
// expected-note @+1 {{see existing live user here}}
return %1 : memref<?xf32, strided<[1], offset: ?>>
}

Expand Down
11 changes: 3 additions & 8 deletions mlir/test/Transforms/test-legalize-type-conversion.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,16 @@ func.func @test_valid_arg_materialization(%arg0: i64) {
// -----

func.func @test_invalid_result_materialization() {
// expected-error@below {{failed to materialize conversion for result #0 of operation 'test.type_producer' that remained live after conversion}}
// expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f16' that remained live after conversion}}
%result = "test.type_producer"() : () -> f16

// expected-note@below {{see existing live user here}}
"foo.return"(%result) : (f16) -> ()
}

// -----

func.func @test_invalid_result_materialization() {
// expected-error@below {{failed to materialize conversion for result #0 of operation 'test.type_producer' that remained live after conversion}}
// expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f16' that remained live after conversion}}
%result = "test.type_producer"() : () -> f16

// expected-note@below {{see existing live user here}}
"foo.return"(%result) : (f16) -> ()
}

Expand All @@ -51,9 +47,8 @@ func.func @test_transitive_use_materialization() {
// -----

func.func @test_transitive_use_invalid_materialization() {
// expected-error@below {{failed to materialize conversion for result #0 of operation 'test.type_producer' that remained live after conversion}}
// expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f16' that remained live after conversion}}
%result = "test.another_type_producer"() : () -> f16
// expected-note@below {{see existing live user here}}
"foo.return"(%result) : (f16) -> ()
}

Expand Down

0 comments on commit 2d9d128

Please sign in to comment.