From a50b047d8a825b8e9e962ae8a3a7c256ae410c9e Mon Sep 17 00:00:00 2001 From: Mike Urbach Date: Thu, 29 Feb 2024 18:04:32 -0800 Subject: [PATCH] [FIRRTL] Handle reference ports when Classes dedup. In https://github.com/llvm/circt/pull/6582, initial support for classes and objects was added in Dedup. However, since classes and objects are new constructs, not every possibility was handled in the initial implementation. This specifically handles the case where references to objects are passed through class ports, and the class type of such objects has changed because their classes deduped. In the final fixup pass through the instance graph, if we find a class, we check if it has any reference ports that need to be updated, and if so update them. When this happens, we also update an objects of the class to reflect the newly updated class type. Fixes https://github.com/llvm/circt/issues/6603. --- lib/Dialect/FIRRTL/Transforms/Dedup.cpp | 84 ++++++++++++++++++++++++- test/firtool/classes-dedupe.fir | 29 +++++++++ 2 files changed, 112 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/FIRRTL/Transforms/Dedup.cpp b/lib/Dialect/FIRRTL/Transforms/Dedup.cpp index 84777ec5ba6c..f6e3341d155c 100644 --- a/lib/Dialect/FIRRTL/Transforms/Dedup.cpp +++ b/lib/Dialect/FIRRTL/Transforms/Dedup.cpp @@ -1391,6 +1391,72 @@ struct Deduper { // Fixup //===----------------------------------------------------------------------===// +/// This fixes up ClassLikes with ClassType ports, when the classes have +/// deduped. For each ClassType port, if the object reference being assigned is +/// a different type, update the port type. Returns true if the ClassOp was +/// updated and the associated ObjectOps should be updated. +bool fixupClassOp(ClassOp classOp) { + // New port type attributes, if necessary. + SmallVector newPortTypes; + bool anyDifferences = false; + + // Check each port. + for (size_t i = 0, e = classOp.getNumPorts(); i < e; ++i) { + // Check if this port is a ClassType. If not, save the original type + // attribute in case we need to update port types. + auto portClassType = dyn_cast(classOp.getPortType(i)); + if (!portClassType) { + newPortTypes.push_back(classOp.getPortTypeAttr(i)); + continue; + } + + // Check if this port is assigned a reference of a different ClassType. + Type newPortClassType; + BlockArgument portArg = classOp.getArgument(i); + for (auto &use : portArg.getUses()) { + if (auto propassign = dyn_cast(use.getOwner())) { + Type sourceType = propassign.getSrc().getType(); + if (propassign.getDest() == use.get() && sourceType != portClassType) { + // Double check that all references are the same new type. + if (newPortClassType) { + assert(newPortClassType == sourceType && + "expected all references to be of the same type"); + continue; + } + + newPortClassType = sourceType; + } + } + } + + // If there was no difference, save the original type attribute in case we + // need to update port types and move along. + if (!newPortClassType) { + newPortTypes.push_back(classOp.getPortTypeAttr(i)); + continue; + } + + // The port type changed, so update the block argument, save the new port + // type attribute, and indicate there was a difference. + classOp.getArgument(i).setType(newPortClassType); + newPortTypes.push_back(TypeAttr::get(newPortClassType)); + anyDifferences = true; + } + + // If necessary, update port types. + if (anyDifferences) + classOp.setPortTypes(newPortTypes); + + return anyDifferences; +} + +/// This fixes up ObjectOps when the signature of their ClassOp changes. This +/// amounts to updating the ObjectOp result type to match the newly updated +/// ClassOp type. +void fixupObjectOp(ObjectOp objectOp, ClassType newClassType) { + objectOp.getResult().setType(newClassType); +} + /// This fixes up connects when the field names of a bundle type changes. It /// finds all fields which were previously bulk connected and legalizes it /// into a connect for each field. @@ -1423,9 +1489,25 @@ void fixupConnect(ImplicitLocOpBuilder &builder, Value dst, Value src) { void fixupAllModules(InstanceGraph &instanceGraph) { for (auto *node : instanceGraph) { auto module = cast(*node->getModule()); + + // Handle class declarations here. + bool shouldFixupObjects = false; + auto classOp = dyn_cast(module.getOperation()); + if (classOp) + shouldFixupObjects = fixupClassOp(classOp); + for (auto *instRec : node->uses()) { + // Handle object instantiations here. + if (classOp) { + if (shouldFixupObjects) { + fixupObjectOp(instRec->getInstance(), + classOp.getInstanceType()); + } + continue; + } + auto inst = instRec->getInstance(); - // Only handle module instantiations for now. + // Only handle module instantiations here. if (!inst) continue; ImplicitLocOpBuilder builder(inst.getLoc(), inst->getContext()); diff --git a/test/firtool/classes-dedupe.fir b/test/firtool/classes-dedupe.fir index ed4f38897978..5b0bbe78a155 100644 --- a/test/firtool/classes-dedupe.fir +++ b/test/firtool/classes-dedupe.fir @@ -81,12 +81,22 @@ circuit Test : %[[ output out_bar : Integer propassign out_bar, Integer(1) + class Foo_3 : + output out_baz : Integer + propassign out_baz, Integer(1) + + class Foo_4 : + output out_baz : Integer + propassign out_baz, Integer(1) + ; CHECK-LABEL: om.class @OM_1(%basepath: !om.basepath) class OM_1 : output out_1 : Path output out_2 : Path output out_foo_1 : Inst output out_foo_2 : Inst + output out_foo_3 : Inst + output out_foo_4 : Inst object foo_1 of Foo_1 propassign out_foo_1, foo_1 @@ -94,17 +104,30 @@ circuit Test : %[[ object foo_2 of Foo_2 propassign out_foo_2, foo_2 + ; CHECK: [[FOO_3:%.+]] = om.object @Foo_3 + object foo_3 of Foo_3 + propassign out_foo_3, foo_3 + + ; CHECK: [[FOO_4:%.+]] = om.object @Foo_3 + object foo_4 of Foo_4 + propassign out_foo_4, foo_4 + ; CHECK: om.path_create reference %basepath [[NLA1]] propassign out_1, path("OMReferenceTarget:~Test|CPU_1>out") ; CHECK: om.path_create reference %basepath [[NLA2]] propassign out_2, path("OMReferenceTarget:~Test|CPU_1/fetch_1:Fetch_1>foo") + ; CHECK: om.class.field @out_foo_3, [[FOO_3]] + ; CHECK: om.class.field @out_foo_4, [[FOO_4]] + ; CHECK-NOT: OM_2 class OM_2 : output out_1 : Path output out_2 : Path output out_foo_1 : Inst output out_foo_2 : Inst + output out_foo_3 : Inst + output out_foo_4 : Inst object foo_1 of Foo_1 propassign out_foo_1, foo_1 @@ -112,5 +135,11 @@ circuit Test : %[[ object foo_2 of Foo_2 propassign out_foo_2, foo_2 + object foo_3 of Foo_3 + propassign out_foo_3, foo_3 + + object foo_4 of Foo_4 + propassign out_foo_4, foo_4 + propassign out_1, path("OMReferenceTarget:~Test|CPU_2>out") propassign out_2, path("OMReferenceTarget:~Test|CPU_2/fetch_1:Fetch_2>foo")