Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MooreToCore] Alternative conditional lowering and type conversion fixes #7625

Merged
merged 1 commit into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 97 additions & 65 deletions lib/Conversion/MooreToCore/MooreToCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1156,6 +1156,40 @@ struct ConditionalOpConversion : public OpConversionPattern<ConditionalOp> {
// evaluated and merged with the appropriate lookup table. See documentation
// for `ConditionalOp`.
auto type = typeConverter->convertType(op.getType());

auto hasNoWriteEffect = [](Region &region) {
auto result = region.walk([](Operation *operation) {
if (auto memOp = dyn_cast<MemoryEffectOpInterface>(operation))
if (!memOp.hasEffect<MemoryEffects::Write>() &&
!memOp.hasEffect<MemoryEffects::Free>())
return WalkResult::advance();

return WalkResult::interrupt();
});
return !result.wasInterrupted();
};

if (hasNoWriteEffect(op.getTrueRegion()) &&
hasNoWriteEffect(op.getFalseRegion())) {
Operation *trueTerm = op.getTrueRegion().front().getTerminator();
Operation *falseTerm = op.getFalseRegion().front().getTerminator();

rewriter.inlineBlockBefore(&op.getTrueRegion().front(), op);
rewriter.inlineBlockBefore(&op.getFalseRegion().front(), op);

Value convTrueVal = typeConverter->materializeTargetConversion(
rewriter, op.getLoc(), type, trueTerm->getOperand(0));
Value convFalseVal = typeConverter->materializeTargetConversion(
rewriter, op.getLoc(), type, falseTerm->getOperand(0));

rewriter.eraseOp(trueTerm);
rewriter.eraseOp(falseTerm);

rewriter.replaceOpWithNewOp<comb::MuxOp>(op, adaptor.getCondition(),
convTrueVal, convFalseVal);
return success();
}

auto ifOp =
rewriter.create<scf::IfOp>(op.getLoc(), type, adaptor.getCondition());
rewriter.inlineRegionBefore(op.getTrueRegion(), ifOp.getThenRegion(),
Expand Down Expand Up @@ -1198,80 +1232,51 @@ struct InPlaceOpConversion : public OpConversionPattern<SourceOp> {
// Conversion Infrastructure
//===----------------------------------------------------------------------===//

static bool isMooreType(Type type) { return isa<UnpackedType>(type); }

static bool hasMooreType(TypeRange types) {
return llvm::any_of(types, isMooreType);
}

static bool hasMooreType(ValueRange values) {
return hasMooreType(values.getTypes());
}

template <typename Op>
static void addGenericLegality(ConversionTarget &target) {
target.addDynamicallyLegalOp<Op>([](Op op) {
return !hasMooreType(op->getOperands()) && !hasMooreType(op->getResults());
});
}

static void populateLegality(ConversionTarget &target) {
static void populateLegality(ConversionTarget &target,
const TypeConverter &converter) {
target.addIllegalDialect<MooreDialect>();
target.addLegalDialect<mlir::BuiltinDialect>();
target.addLegalDialect<hw::HWDialect>();
target.addLegalDialect<llhd::LLHDDialect>();
target.addLegalDialect<comb::CombDialect>();

addGenericLegality<cf::CondBranchOp>(target);
addGenericLegality<cf::BranchOp>(target);
addGenericLegality<scf::IfOp>(target);
addGenericLegality<scf::YieldOp>(target);
addGenericLegality<func::CallOp>(target);
addGenericLegality<func::ReturnOp>(target);
addGenericLegality<UnrealizedConversionCastOp>(target);

target.addLegalOp<debug::ScopeOp>();
addGenericLegality<debug::ArrayOp>(target);
addGenericLegality<debug::StructOp>(target);
addGenericLegality<debug::VariableOp>(target);

target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp op) {
auto argsConverted = llvm::none_of(op.getBlocks(), [](auto &block) {
return hasMooreType(block.getArguments());
});
auto resultsConverted = !hasMooreType(op.getResultTypes());
return argsConverted && resultsConverted;
});

target.addDynamicallyLegalOp<hw::HWModuleOp>([](hw::HWModuleOp op) {
return !hasMooreType(op.getInputTypes()) &&
!hasMooreType(op.getOutputTypes()) &&
!hasMooreType(op.getBody().getArgumentTypes());
});
target.addDynamicallyLegalOp<
cf::CondBranchOp, cf::BranchOp, scf::IfOp, scf::YieldOp, func::CallOp,
func::ReturnOp, UnrealizedConversionCastOp, hw::OutputOp, hw::InstanceOp,
debug::ArrayOp, debug::StructOp, debug::VariableOp>(
[&](Operation *op) { return converter.isLegal(op); });

target.addDynamicallyLegalOp<hw::InstanceOp>([](hw::InstanceOp op) {
return !hasMooreType(op.getInputs()) && !hasMooreType(op.getResults());
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
return converter.isSignatureLegal(op.getFunctionType()) &&
converter.isLegal(&op.getFunctionBody());
});

target.addDynamicallyLegalOp<hw::OutputOp>(
[](hw::OutputOp op) { return !hasMooreType(op.getOutputs()); });
target.addDynamicallyLegalOp<hw::HWModuleOp>([&](hw::HWModuleOp op) {
return converter.isSignatureLegal(op.getModuleType().getFuncType()) &&
converter.isLegal(&op.getBody());
});
}

static void populateTypeConversion(TypeConverter &typeConverter) {
typeConverter.addConversion([&](IntType type) {
return IntegerType::get(type.getContext(), type.getWidth());
});

typeConverter.addConversion([&](ArrayType type) {
return hw::ArrayType::get(typeConverter.convertType(type.getElementType()),
type.getSize());
typeConverter.addConversion([&](ArrayType type) -> std::optional<Type> {
if (auto elementType = typeConverter.convertType(type.getElementType()))
return hw::ArrayType::get(elementType, type.getSize());
return {};
});

typeConverter.addConversion([&](StructType type) {
typeConverter.addConversion([&](StructType type) -> std::optional<Type> {
SmallVector<hw::StructType::FieldInfo> fields;
for (auto field : type.getMembers()) {
hw::StructType::FieldInfo info;
info.type = typeConverter.convertType(field.type);
if (!info.type)
return {};
info.name = field.name;
fields.push_back(info);
}
Expand All @@ -1283,20 +1288,22 @@ static void populateTypeConversion(TypeConverter &typeConverter) {
// differences though they look similarily. The packed data structure is
// contiguous in memory but another is opposite. The differences will affect
// data layout and granularity of event tracking in simulation.
typeConverter.addConversion([&](UnpackedStructType type) {
SmallVector<hw::StructType::FieldInfo> fields;
for (auto field : type.getMembers()) {
hw::StructType::FieldInfo info;
info.type = typeConverter.convertType(field.type);
info.name = field.name;
fields.push_back(info);
}
return hw::StructType::get(type.getContext(), fields);
});
typeConverter.addConversion(
[&](UnpackedStructType type) -> std::optional<Type> {
SmallVector<hw::StructType::FieldInfo> fields;
for (auto field : type.getMembers()) {
hw::StructType::FieldInfo info;
info.type = typeConverter.convertType(field.type);
if (!info.type)
return {};
info.name = field.name;
fields.push_back(info);
}
return hw::StructType::get(type.getContext(), fields);
});

typeConverter.addConversion([&](RefType type) -> std::optional<Type> {
auto innerType = typeConverter.convertType(type.getNestedType());
if (innerType)
if (auto innerType = typeConverter.convertType(type.getNestedType()))
return hw::InOutType::get(innerType);
return {};
});
Expand All @@ -1307,6 +1314,31 @@ static void populateTypeConversion(TypeConverter &typeConverter) {
typeConverter.addConversion([](debug::ScopeType type) { return type; });
typeConverter.addConversion([](debug::StructType type) { return type; });

typeConverter.addConversion([&](hw::InOutType type) -> std::optional<Type> {
if (auto innerType = typeConverter.convertType(type.getElementType()))
return hw::InOutType::get(innerType);
return {};
});

typeConverter.addConversion([&](hw::ArrayType type) -> std::optional<Type> {
if (auto elementType = typeConverter.convertType(type.getElementType()))
return hw::ArrayType::get(elementType, type.getNumElements());
return {};
});

typeConverter.addConversion([&](hw::StructType type) -> std::optional<Type> {
SmallVector<hw::StructType::FieldInfo> fields;
for (auto field : type.getElements()) {
hw::StructType::FieldInfo info;
info.type = typeConverter.convertType(field.type);
if (!info.type)
return {};
info.name = field.name;
fields.push_back(info);
}
return hw::StructType::get(type.getContext(), fields);
});

typeConverter.addTargetMaterialization(
[&](mlir::OpBuilder &builder, mlir::Type resultType,
mlir::ValueRange inputs,
Expand Down Expand Up @@ -1403,8 +1435,8 @@ static void populateOpConversion(RewritePatternSet &patterns,
InPlaceOpConversion<debug::VariableOp>
>(typeConverter, context);
// clang-format on
mlir::populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns, typeConverter);
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
typeConverter);

hw::populateHWModuleLikeTypeConversionPattern(
hw::HWModuleOp::getOperationName(), patterns, typeConverter);
Expand Down Expand Up @@ -1437,8 +1469,8 @@ void MooreToCorePass::runOnOperation() {
ConversionTarget target(context);
TypeConverter typeConverter;
RewritePatternSet patterns(&context);
populateLegality(target);
populateTypeConversion(typeConverter);
populateLegality(target, typeConverter);
populateOpConversion(patterns, typeConverter);

if (failed(applyFullConversion(module, target, std::move(patterns))))
Expand Down
12 changes: 6 additions & 6 deletions test/Conversion/MooreToCore/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -258,12 +258,8 @@ func.func @Expressions(%arg0: !moore.i1, %arg1: !moore.l1, %arg2: !moore.i6, %ar
moore.wildcard_eq %arg0, %arg0 : !moore.i1 -> !moore.i1
moore.wildcard_ne %arg0, %arg0 : !moore.i1 -> !moore.i1

// CHECK-NEXT: [[RES:%.+]] = scf.if %arg0 -> (i6) {
// CHECK-NEXT: scf.yield %arg2 : i6
// CHECK-NEXT: } else {
// CHECK-NEXT: [[TMP:%.+]] = hw.constant 19 : i6
// CHECK-NEXT: scf.yield [[TMP]] : i6
// CHECK-NEXT: }
// CHECK-NEXT: [[TMP:%.+]] = hw.constant 19 : i6
// CHECK-NEXT: [[RES:%.+]] = comb.mux %arg0, %arg2, [[TMP]]
// CHECK-NEXT: comb.parity [[RES]] : i6
%k0 = moore.conditional %arg0 : i1 -> i6 {
moore.yield %arg2 : i6
Expand All @@ -275,6 +271,8 @@ func.func @Expressions(%arg0: !moore.i1, %arg1: !moore.l1, %arg2: !moore.i6, %ar

// CHECK-NEXT: [[RES:%.+]] = scf.if %arg1 -> (i6) {
// CHECK-NEXT: [[TMP:%.+]] = hw.constant 0 : i6
// CHECK: %var_l6 = llhd.sig
// CHECK: llhd.drv %var_l6, [[TMP]] after
// CHECK-NEXT: scf.yield [[TMP]] : i6
// CHECK-NEXT: } else {
// CHECK-NEXT: [[TMP:%.+]] = hw.constant 19 : i6
Expand All @@ -283,6 +281,8 @@ func.func @Expressions(%arg0: !moore.i1, %arg1: !moore.l1, %arg2: !moore.i6, %ar
// CHECK-NEXT: comb.parity [[RES]] : i6
%k1 = moore.conditional %arg1 : l1 -> l6 {
%0 = moore.constant bXXXXXX : l6
%var_l6 = moore.variable : <l6>
moore.blocking_assign %var_l6, %0 : l6
moore.yield %0 : l6
} {
%0 = moore.constant 19 : l6
Expand Down
Loading