From 9aa32b4df0fa5987321c5f0ba0bfc3b4a1337e0e Mon Sep 17 00:00:00 2001 From: Prithayan Barua Date: Wed, 6 Mar 2024 09:13:49 -0500 Subject: [PATCH 1/4] Add heuristic to infer unmasked memory --- .../FIRRTL/Transforms/InferReadWrite.cpp | 75 +++++++++++++++++-- test/Dialect/FIRRTL/inferRW.mlir | 20 +++++ 2 files changed, 90 insertions(+), 5 deletions(-) diff --git a/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp b/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp index 470caddf8a5c..4e3c55a92094 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp @@ -314,6 +314,69 @@ struct InferReadWritePass : public InferReadWriteBase { return {}; } + // Try to extract the value assigned to each bit of `val`. This is a heuristic + // to determine if each bit of the `val` is assigned the same value. + // NOLINTNEXTLINE(misc-no-recursion) + bool getBitsOfVal(Value val, SmallVector &bits) { + + auto size = getBitWidth(type_cast(val.getType())); + if (!size.has_value()) + return false; + bits.resize_for_overwrite(size.value()); + if (auto *defOp = val.getDefiningOp()) { + + if (isa(defOp)) { + long lastSize = 0; + for (auto operand : defOp->getOperands()) { + SmallVector opBits; + if (!getBitsOfVal(operand, opBits)) + return false; + auto s = + getBitWidth(type_cast(operand.getType())).value(); + for (long i = lastSize, e = lastSize + s; i != e; ++i) + bits[i] = opBits[i - lastSize]; + lastSize = s; + } + return true; + } + if (auto bitsPrim = dyn_cast(defOp)) { + SmallVector opBits; + if (!getBitsOfVal(bitsPrim.getInput(), opBits)) + return false; + for (size_t srcIndex = bitsPrim.getLo(), e = bitsPrim.getHi(), i = 0; + srcIndex <= e; ++srcIndex, ++i) + bits[i] = opBits[srcIndex]; + return true; + } + if (auto constOp = dyn_cast(defOp)) { + auto constVal = constOp.getValue(); + if (constVal.isAllOnes() || constVal.isZero()) { + for (auto &b : bits) + b = constOp; + return true; + } + return false; + } + if (auto wireOp = dyn_cast(defOp)) { + SmallVector wireBits; + if (auto src = getConnectSrc(wireOp.getResult())) { + if (getBitsOfVal(src, wireBits)) + bits = wireBits; + + } else { + for (auto &b : bits) + b = wireOp.getResult(); + } + return true; + } + } + if (size.value() == 1) { + bits[0] = val; + return true; + } + return false; + } + // Remove redundant dependence of wmode on the enable signal. wmode can assume // the enable signal be true. void simplifyWmode(MemOp &memOp) { @@ -404,11 +467,13 @@ struct InferReadWritePass : public InferReadWriteBase { if (sf.getResult().getType().getBitWidthOrSentinel() == 1) continue; // Check what is the mask field directly connected to. - // If, a constant 1, then we can replace with unMasked memory. - if (auto maskVal = getConnectSrc(sf)) - if (auto constVal = dyn_cast(maskVal.getDefiningOp())) - if (constVal.getValue().isAllOnes()) - isMasked = false; + // If we can infer that all the bits of the mask are always assigned + // the same value, then the memory is unmasked. + if (auto maskVal = getConnectSrc(sf)) { + SmallVector bits; + if (getBitsOfVal(maskVal, bits)) + isMasked = !llvm::all_equal(bits); + } } } } diff --git a/test/Dialect/FIRRTL/inferRW.mlir b/test/Dialect/FIRRTL/inferRW.mlir index 1c3e959f74db..61a2d479e418 100644 --- a/test/Dialect/FIRRTL/inferRW.mlir +++ b/test/Dialect/FIRRTL/inferRW.mlir @@ -302,5 +302,25 @@ firrtl.circuit "TLRAM" { // CHECK: %[[v7:.+]] = firrtl.mux(%[[c1_ui1]], %rwPort_isWrite, %c0_ui1) firrtl.strictconnect %mem_rwPort_readData_rw_wmode, %18 : !firrtl.uint<1> } + + // CHECK: firrtl.module @InferUnmasked + firrtl.module @InferUnmasked(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>) attributes {convention = #firrtl} { + %readwritePortA_isWrite_2 = firrtl.wire {name = "readwritePortA_isWrite"} : !firrtl.uint<1> + %syncreadmem_singleport_readwritePortA_readData_rw = firrtl.mem Undefined {depth = 64 : i64, name = "syncreadmem_singleport", portNames = ["readwritePortA_readData_rw"], readLatency = 1 : i32, writeLatency = 1 : i32} : !firrtl.bundle, en: uint<1>, clk: clock, rdata flip: uint<10>, wmode: uint<1>, wdata: uint<10>, wmask: uint<5>> + // CHECK: %syncreadmem_singleport_readwritePortA_readData_rw = firrtl.mem Undefined {depth = 64 : i64, name = "syncreadmem_singleport", portNames = ["readwritePortA_readData_rw"], readLatency = 1 : i32, writeLatency = 1 : i32} : !firrtl.bundle, en: uint<1>, clk: clock, rdata flip: uint<10>, wmode: uint<1>, wdata: uint<10>, wmask: uint<1>> + %syncreadmem_singleport_readwritePortA_readData_rw_wmask_x = firrtl.wire : !firrtl.uint<1> + %syncreadmem_singleport_readwritePortA_readData_rw_wmask_y = firrtl.wire : !firrtl.uint<1> + %9 = firrtl.subfield %syncreadmem_singleport_readwritePortA_readData_rw[wmask] : !firrtl.bundle, en: uint<1>, clk: clock, rdata flip: uint<10>, wmode: uint<1>, wdata: uint<10>, wmask: uint<5>> + %10 = firrtl.cat %syncreadmem_singleport_readwritePortA_readData_rw_wmask_y, %syncreadmem_singleport_readwritePortA_readData_rw_wmask_x : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<2> + %11 = firrtl.bits %10 0 to 0 : (!firrtl.uint<2>) -> !firrtl.uint<1> + %12 = firrtl.cat %11, %11 : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<2> + %13 = firrtl.cat %11, %12 : (!firrtl.uint<1>, !firrtl.uint<2>) -> !firrtl.uint<3> + %14 = firrtl.bits %10 1 to 1 : (!firrtl.uint<2>) -> !firrtl.uint<1> + %15 = firrtl.cat %14, %13 : (!firrtl.uint<1>, !firrtl.uint<3>) -> !firrtl.uint<4> + %16 = firrtl.cat %14, %15 : (!firrtl.uint<1>, !firrtl.uint<4>) -> !firrtl.uint<5> + firrtl.strictconnect %9, %16 : !firrtl.uint<5> + firrtl.connect %syncreadmem_singleport_readwritePortA_readData_rw_wmask_x, %readwritePortA_isWrite_2 : !firrtl.uint<1>, !firrtl.uint<1> + firrtl.connect %syncreadmem_singleport_readwritePortA_readData_rw_wmask_y, %readwritePortA_isWrite_2 : !firrtl.uint<1>, !firrtl.uint<1> + } } From fb69e7656ba77814dee6d59bbae74f1725fea55f Mon Sep 17 00:00:00 2001 From: Prithayan Barua Date: Mon, 11 Mar 2024 01:34:53 -0400 Subject: [PATCH 2/4] WIP: Iterative --- lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp b/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp index 4e3c55a92094..7c6cf623d9d7 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp @@ -18,6 +18,7 @@ #include "circt/Dialect/FIRRTL/Passes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "llvm/ADT/APSInt.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" @@ -318,6 +319,13 @@ struct InferReadWritePass : public InferReadWriteBase { // to determine if each bit of the `val` is assigned the same value. // NOLINTNEXTLINE(misc-no-recursion) bool getBitsOfVal(Value val, SmallVector &bits) { + SmallVector stack; + llvm::DenseMap> valueBitsSrc; + stack.push_back(val); + + while (!stack.empty()) { + auto currentNode = stack.back(); + } auto size = getBitWidth(type_cast(val.getType())); if (!size.has_value()) From 3f9a6f4e4a3514e791d5dd0e1f2a7f6478564ea9 Mon Sep 17 00:00:00 2001 From: Prithayan Barua Date: Tue, 12 Mar 2024 17:06:46 -0400 Subject: [PATCH 3/4] Implement an iterative analysis --- .../FIRRTL/Transforms/InferReadWrite.cpp | 144 ++++++++++-------- 1 file changed, 82 insertions(+), 62 deletions(-) diff --git a/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp b/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp index 7c6cf623d9d7..5d94c075129e 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp @@ -315,74 +315,92 @@ struct InferReadWritePass : public InferReadWriteBase { return {}; } + void handleCatPrimOp(CatPrimOp defOp, SmallVectorImpl &bits) { + + long lastSize = 0; + // Cat the bits of both the operands. + for (auto operand : defOp->getOperands()) { + SmallVectorImpl &opBits = valueBitsSrc[operand]; + size_t s = + getBitWidth(type_cast(operand.getType())).value(); + assert(opBits.size() == s); + for (long i = lastSize, e = lastSize + s; i != e; ++i) + bits[i] = opBits[i - lastSize]; + lastSize = s; + } + } + + void handleBitsPrimOp(BitsPrimOp bitsPrim, SmallVectorImpl &bits) { + + SmallVectorImpl &opBits = valueBitsSrc[bitsPrim.getInput()]; + for (size_t srcIndex = bitsPrim.getLo(), e = bitsPrim.getHi(), i = 0; + srcIndex <= e; ++srcIndex, ++i) + bits[i] = opBits[srcIndex]; + } + // Try to extract the value assigned to each bit of `val`. This is a heuristic // to determine if each bit of the `val` is assigned the same value. - // NOLINTNEXTLINE(misc-no-recursion) - bool getBitsOfVal(Value val, SmallVector &bits) { + bool areBitsDrivenBySameSource(Value val) { SmallVector stack; - llvm::DenseMap> valueBitsSrc; stack.push_back(val); while (!stack.empty()) { - auto currentNode = stack.back(); - } + auto val = stack.back(); + auto size = getBitWidth(type_cast(val.getType())); + // Cannot analyze aggregate types. + if (!size.has_value()) + return false; - auto size = getBitWidth(type_cast(val.getType())); - if (!size.has_value()) - return false; - bits.resize_for_overwrite(size.value()); - if (auto *defOp = val.getDefiningOp()) { - - if (isa(defOp)) { - long lastSize = 0; - for (auto operand : defOp->getOperands()) { - SmallVector opBits; - if (!getBitsOfVal(operand, opBits)) + auto bitsSize = size.value(); + if (auto *defOp = val.getDefiningOp()) { + if (isa(defOp)) { + bool operandsDone = true; + // If the value is a cat of other values, compute the bits of the + // operands. + for (auto operand : defOp->getOperands()) { + if (valueBitsSrc.contains(operand)) + continue; + stack.push_back(operand); + operandsDone = false; + } + if (!operandsDone) + continue; + + valueBitsSrc[val].resize_for_overwrite(bitsSize); + handleCatPrimOp(cast(defOp), valueBitsSrc[val]); + } else if (auto bitsPrim = dyn_cast(defOp)) { + auto input = bitsPrim.getInput(); + if (!valueBitsSrc.contains(input)) { + stack.push_back(input); + continue; + } + valueBitsSrc[val].resize_for_overwrite(bitsSize); + handleBitsPrimOp(bitsPrim, valueBitsSrc[val]); + } else if (auto constOp = dyn_cast(defOp)) { + auto constVal = constOp.getValue(); + valueBitsSrc[val].resize_for_overwrite(bitsSize); + if (constVal.isAllOnes() || constVal.isZero()) { + for (auto &b : valueBitsSrc[val]) + b = constOp; + } else return false; - auto s = - getBitWidth(type_cast(operand.getType())).value(); - for (long i = lastSize, e = lastSize + s; i != e; ++i) - bits[i] = opBits[i - lastSize]; - lastSize = s; - } - return true; - } - if (auto bitsPrim = dyn_cast(defOp)) { - SmallVector opBits; - if (!getBitsOfVal(bitsPrim.getInput(), opBits)) + } else if (auto wireOp = dyn_cast(defOp)) { + if (bitsSize != 1) + return false; + valueBitsSrc[val].resize_for_overwrite(bitsSize); + if (auto src = getConnectSrc(wireOp.getResult())) { + valueBitsSrc[val][0] = src; + } else + valueBitsSrc[val][0] = wireOp.getResult(); + } else return false; - for (size_t srcIndex = bitsPrim.getLo(), e = bitsPrim.getHi(), i = 0; - srcIndex <= e; ++srcIndex, ++i) - bits[i] = opBits[srcIndex]; - return true; - } - if (auto constOp = dyn_cast(defOp)) { - auto constVal = constOp.getValue(); - if (constVal.isAllOnes() || constVal.isZero()) { - for (auto &b : bits) - b = constOp; - return true; - } + } else return false; - } - if (auto wireOp = dyn_cast(defOp)) { - SmallVector wireBits; - if (auto src = getConnectSrc(wireOp.getResult())) { - if (getBitsOfVal(src, wireBits)) - bits = wireBits; - - } else { - for (auto &b : bits) - b = wireOp.getResult(); - } - return true; - } - } - if (size.value() == 1) { - bits[0] = val; - return true; + stack.pop_back(); } - return false; + if (!valueBitsSrc.contains(val)) + return false; + return llvm::all_equal(valueBitsSrc[val]); } // Remove redundant dependence of wmode on the enable signal. wmode can assume @@ -477,11 +495,9 @@ struct InferReadWritePass : public InferReadWriteBase { // Check what is the mask field directly connected to. // If we can infer that all the bits of the mask are always assigned // the same value, then the memory is unmasked. - if (auto maskVal = getConnectSrc(sf)) { - SmallVector bits; - if (getBitsOfVal(maskVal, bits)) - isMasked = !llvm::all_equal(bits); - } + if (auto maskVal = getConnectSrc(sf)) + if (areBitsDrivenBySameSource(maskVal)) + isMasked = false; } } } @@ -540,6 +556,10 @@ struct InferReadWritePass : public InferReadWriteBase { memOp = newMem; } } + + // Record of what are the source values that drive each bit of a value. Used + // to check if each bit of a value is being driven by the same source. + llvm::DenseMap> valueBitsSrc; }; } // end anonymous namespace From f6cc014b3e981571d0b1faa86291b7352dc440ee Mon Sep 17 00:00:00 2001 From: Prithayan Barua Date: Tue, 12 Mar 2024 17:38:55 -0400 Subject: [PATCH 4/4] Update comments --- lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp b/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp index 5d94c075129e..03eaae0da3cd 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp @@ -18,7 +18,6 @@ #include "circt/Dialect/FIRRTL/Passes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "llvm/ADT/APSInt.h" -#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" @@ -340,12 +339,21 @@ struct InferReadWritePass : public InferReadWriteBase { // Try to extract the value assigned to each bit of `val`. This is a heuristic // to determine if each bit of the `val` is assigned the same value. + // Common pattern that this heuristic detects, + // mask = {{w1,w1},{w2,w2}}} + // w1 = w[0] + // w2 = w[0] bool areBitsDrivenBySameSource(Value val) { SmallVector stack; stack.push_back(val); while (!stack.empty()) { auto val = stack.back(); + if (valueBitsSrc.contains(val)) { + stack.pop_back(); + continue; + } + auto size = getBitWidth(type_cast(val.getType())); // Cannot analyze aggregate types. if (!size.has_value())