Skip to content

Commit

Permalink
Revert "[mlir][Transforms] Dialect conversion: Build unresolved mater…
Browse files Browse the repository at this point in the history
…ialization for replaced ops (llvm#101514)"

This reverts commit 2d50029.
  • Loading branch information
raikonenfnu committed Aug 16, 2024
1 parent ddda37a commit 8835909
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 57 deletions.
126 changes: 75 additions & 51 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2348,12 +2348,6 @@ struct OperationConverter {
legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &rewriterImpl);

/// Legalize the types of converted op results.
LogicalResult legalizeConvertedOpResultTypes(
ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &rewriterImpl,
DenseMap<Value, SmallVector<Value>> &inverseMapping);

/// Legalize any unresolved type materializations.
LogicalResult legalizeUnresolvedMaterializations(
ConversionPatternRewriter &rewriter,
Expand All @@ -2365,6 +2359,14 @@ struct OperationConverter {
legalizeErasedResult(Operation *op, OpResult result,
ConversionPatternRewriterImpl &rewriterImpl);

/// Legalize an operation result that was replaced with a value of a different
/// type.
LogicalResult legalizeChangedResultType(
Operation *op, OpResult result, Value newValue,
const TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &rewriterImpl,
const DenseMap<Value, SmallVector<Value>> &inverseMapping);

/// Dialect conversion configuration.
ConversionConfig config;

Expand Down Expand Up @@ -2457,42 +2459,10 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
return failure();
DenseMap<Value, SmallVector<Value>> inverseMapping =
rewriterImpl.mapping.getInverse();
if (failed(legalizeConvertedOpResultTypes(rewriter, rewriterImpl,
inverseMapping)))
return failure();
if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
inverseMapping)))
return failure();
return success();
}

/// Finds a user of the given value, or of any other value that the given value
/// replaced, that was not replaced in the conversion process.
static Operation *findLiveUserOfReplaced(
Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
SmallVector<Value> worklist = {initialValue};
while (!worklist.empty()) {
Value value = worklist.pop_back_val();

// Walk the users of this value to see if there are any live users that
// weren't replaced during conversion.
auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) {
return rewriterImpl.isOpIgnored(user);
});
if (liveUserIt != value.user_end())
return *liveUserIt;
auto mapIt = inverseMapping.find(value);
if (mapIt != inverseMapping.end())
worklist.append(mapIt->second);
}
return nullptr;
}

LogicalResult OperationConverter::legalizeConvertedOpResultTypes(
ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &rewriterImpl,
DenseMap<Value, SmallVector<Value>> &inverseMapping) {
// Process requested operation replacements.
for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) {
auto *opReplacement =
Expand All @@ -2515,21 +2485,14 @@ LogicalResult OperationConverter::legalizeConvertedOpResultTypes(
if (result.getType() == newValue.getType())
continue;

Operation *liveUser =
findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
if (!liveUser)
continue;

// Legalize this result.
Value castValue = rewriterImpl.buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(result), op->getLoc(),
/*inputs=*/newValue, /*outputType=*/result.getType(),
opReplacement->getConverter());
rewriterImpl.mapping.map(result, castValue);
inverseMapping[castValue].push_back(result);
rewriter.setInsertionPoint(op);
if (failed(legalizeChangedResultType(
op, result, newValue, opReplacement->getConverter(), rewriter,
rewriterImpl, inverseMapping)))
return failure();
}
}

return success();
}

Expand All @@ -2539,7 +2502,7 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
// Functor used to check if all users of a value will be dead after
// conversion.
// TODO: This should probably query the inverse mapping, same as in
// `legalizeConvertedOpResultTypes`.
// `legalizeChangedResultType`.
auto findLiveUser = [&](Value val) {
auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
return rewriterImpl.isOpIgnored(user);
Expand Down Expand Up @@ -2869,6 +2832,67 @@ LogicalResult OperationConverter::legalizeErasedResult(
return success();
}

/// Finds a user of the given value, or of any other value that the given value
/// replaced, that was not replaced in the conversion process.
static Operation *findLiveUserOfReplaced(
Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
SmallVector<Value> worklist(1, initialValue);
while (!worklist.empty()) {
Value value = worklist.pop_back_val();

// Walk the users of this value to see if there are any live users that
// weren't replaced during conversion.
auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) {
return rewriterImpl.isOpIgnored(user);
});
if (liveUserIt != value.user_end())
return *liveUserIt;
auto mapIt = inverseMapping.find(value);
if (mapIt != inverseMapping.end())
worklist.append(mapIt->second);
}
return nullptr;
}

LogicalResult OperationConverter::legalizeChangedResultType(
Operation *op, OpResult result, Value newValue,
const TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &rewriterImpl,
const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
Operation *liveUser =
findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
if (!liveUser)
return success();

// Functor used to emit a conversion error for a failed materialization.
auto emitConversionError = [&] {
InFlightDiagnostic diag = op->emitError()
<< "failed to materialize conversion for result #"
<< result.getResultNumber() << " of operation '"
<< op->getName()
<< "' that remained live after conversion";
diag.attachNote(liveUser->getLoc())
<< "see existing live user here: " << *liveUser;
return failure();
};

// If the replacement has a type converter, attempt to materialize a
// conversion back to the original type.
if (!replConverter)
return emitConversionError();

// Materialize a conversion for this live result value.
Type resultType = result.getType();
Value convertedValue = replConverter->materializeSourceConversion(
rewriter, op->getLoc(), resultType, newValue);
if (!convertedValue)
return emitConversionError();

rewriterImpl.mapping.map(result, convertedValue);
return success();
}

//===----------------------------------------------------------------------===//
// Type Conversion
//===----------------------------------------------------------------------===//
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -560,8 +560,8 @@ func.func @deinterleave(%a: vector<4xf32>) -> (vector<2xf32>, vector<2xf32>) {
// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>)
// CHECK: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
// CHECK: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
// CHECK-DAG: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT0]] : f32 to vector<1xf32>
// CHECK-DAG: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT1]] : f32 to vector<1xf32>
// CHECK: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT0]] : f32 to vector<1xf32>
// CHECK: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT1]] : f32 to vector<1xf32>
// CHECK: return %[[CAST0]], %[[CAST1]]
func.func @deinterleave_scalar(%a: vector<2xf32>) -> (vector<1xf32>, vector<1xf32>) {
%0, %1 = vector.deinterleave %a: vector<2xf32> -> vector<1xf32>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,9 @@ func.func @static_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset
// memref.cast.
func.func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, strided<[1], offset: ?>> {
%0 = bufferization.to_tensor %m : memref<?xf32>
// expected-error @+1 {{failed to legalize unresolved materialization from ('memref<?xf32>') to 'memref<?xf32, strided<[1], offset: ?>>' that remained live after conversion}}
// expected-error @+1 {{failed to materialize conversion for result #0 of operation 'bufferization.to_memref' that remained live after conversion}}
%1 = bufferization.to_memref %0 : memref<?xf32, strided<[1], offset: ?>>
// expected-note @+1 {{see existing live user here}}
return %1 : memref<?xf32, strided<[1], offset: ?>>
}

Expand Down
11 changes: 8 additions & 3 deletions mlir/test/Transforms/test-legalize-type-conversion.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,20 @@ func.func @test_valid_arg_materialization(%arg0: i64) {
// -----

func.func @test_invalid_result_materialization() {
// expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f16' that remained live after conversion}}
// expected-error@below {{failed to materialize conversion for result #0 of operation 'test.type_producer' that remained live after conversion}}
%result = "test.type_producer"() : () -> f16

// expected-note@below {{see existing live user here}}
"foo.return"(%result) : (f16) -> ()
}

// -----

func.func @test_invalid_result_materialization() {
// expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f16' that remained live after conversion}}
// expected-error@below {{failed to materialize conversion for result #0 of operation 'test.type_producer' that remained live after conversion}}
%result = "test.type_producer"() : () -> f16

// expected-note@below {{see existing live user here}}
"foo.return"(%result) : (f16) -> ()
}

Expand All @@ -47,8 +51,9 @@ func.func @test_transitive_use_materialization() {
// -----

func.func @test_transitive_use_invalid_materialization() {
// expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f16' that remained live after conversion}}
// expected-error@below {{failed to materialize conversion for result #0 of operation 'test.type_producer' that remained live after conversion}}
%result = "test.another_type_producer"() : () -> f16
// expected-note@below {{see existing live user here}}
"foo.return"(%result) : (f16) -> ()
}

Expand Down

0 comments on commit 8835909

Please sign in to comment.