Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[InferReadWrite] Add heuristic to infer unmasked memory #6790

Merged
merged 4 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 105 additions & 4 deletions lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,103 @@ struct InferReadWritePass : public InferReadWriteBase<InferReadWritePass> {
return {};
}

void handleCatPrimOp(CatPrimOp defOp, SmallVectorImpl<Value> &bits) {

long lastSize = 0;
// Cat the bits of both the operands.
for (auto operand : defOp->getOperands()) {
SmallVectorImpl<Value> &opBits = valueBitsSrc[operand];
size_t s =
getBitWidth(type_cast<FIRRTLBaseType>(operand.getType())).value();
assert(opBits.size() == s);
for (long i = lastSize, e = lastSize + s; i != e; ++i)
bits[i] = opBits[i - lastSize];
lastSize = s;
}
}

void handleBitsPrimOp(BitsPrimOp bitsPrim, SmallVectorImpl<Value> &bits) {

SmallVectorImpl<Value> &opBits = valueBitsSrc[bitsPrim.getInput()];
for (size_t srcIndex = bitsPrim.getLo(), e = bitsPrim.getHi(), i = 0;
srcIndex <= e; ++srcIndex, ++i)
bits[i] = opBits[srcIndex];
}

// Try to extract the value assigned to each bit of `val`. This is a heuristic
// to determine if each bit of the `val` is assigned the same value.
// Common pattern that this heuristic detects,
// mask = {{w1,w1},{w2,w2}}}
// w1 = w[0]
// w2 = w[0]
bool areBitsDrivenBySameSource(Value val) {
SmallVector<Value> stack;
stack.push_back(val);

while (!stack.empty()) {
auto val = stack.back();
if (valueBitsSrc.contains(val)) {
stack.pop_back();
continue;
}

auto size = getBitWidth(type_cast<FIRRTLBaseType>(val.getType()));
// Cannot analyze aggregate types.
if (!size.has_value())
return false;

auto bitsSize = size.value();
if (auto *defOp = val.getDefiningOp()) {
if (isa<CatPrimOp>(defOp)) {
bool operandsDone = true;
// If the value is a cat of other values, compute the bits of the
// operands.
for (auto operand : defOp->getOperands()) {
if (valueBitsSrc.contains(operand))
continue;
stack.push_back(operand);
operandsDone = false;
}
if (!operandsDone)
continue;

valueBitsSrc[val].resize_for_overwrite(bitsSize);
handleCatPrimOp(cast<CatPrimOp>(defOp), valueBitsSrc[val]);
} else if (auto bitsPrim = dyn_cast<BitsPrimOp>(defOp)) {
auto input = bitsPrim.getInput();
if (!valueBitsSrc.contains(input)) {
stack.push_back(input);
continue;
}
valueBitsSrc[val].resize_for_overwrite(bitsSize);
handleBitsPrimOp(bitsPrim, valueBitsSrc[val]);
} else if (auto constOp = dyn_cast<ConstantOp>(defOp)) {
auto constVal = constOp.getValue();
valueBitsSrc[val].resize_for_overwrite(bitsSize);
if (constVal.isAllOnes() || constVal.isZero()) {
for (auto &b : valueBitsSrc[val])
b = constOp;
} else
return false;
} else if (auto wireOp = dyn_cast<WireOp>(defOp)) {
if (bitsSize != 1)
return false;
valueBitsSrc[val].resize_for_overwrite(bitsSize);
if (auto src = getConnectSrc(wireOp.getResult())) {
valueBitsSrc[val][0] = src;
} else
valueBitsSrc[val][0] = wireOp.getResult();
} else
return false;
} else
return false;
stack.pop_back();
}
if (!valueBitsSrc.contains(val))
return false;
return llvm::all_equal(valueBitsSrc[val]);
}

// Remove redundant dependence of wmode on the enable signal. wmode can assume
// the enable signal be true.
void simplifyWmode(MemOp &memOp) {
Expand Down Expand Up @@ -404,11 +501,11 @@ struct InferReadWritePass : public InferReadWriteBase<InferReadWritePass> {
if (sf.getResult().getType().getBitWidthOrSentinel() == 1)
continue;
// Check what is the mask field directly connected to.
// If, a constant 1, then we can replace with unMasked memory.
// If we can infer that all the bits of the mask are always assigned
// the same value, then the memory is unmasked.
if (auto maskVal = getConnectSrc(sf))
if (auto constVal = dyn_cast<ConstantOp>(maskVal.getDefiningOp()))
if (constVal.getValue().isAllOnes())
isMasked = false;
if (areBitsDrivenBySameSource(maskVal))
isMasked = false;
}
}
}
Expand Down Expand Up @@ -467,6 +564,10 @@ struct InferReadWritePass : public InferReadWriteBase<InferReadWritePass> {
memOp = newMem;
}
}

// Record of what are the source values that drive each bit of a value. Used
// to check if each bit of a value is being driven by the same source.
llvm::DenseMap<Value, SmallVector<Value>> valueBitsSrc;
};
} // end anonymous namespace

Expand Down
20 changes: 20 additions & 0 deletions test/Dialect/FIRRTL/inferRW.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -302,5 +302,25 @@ firrtl.circuit "TLRAM" {
// CHECK: %[[v7:.+]] = firrtl.mux(%[[c1_ui1]], %rwPort_isWrite, %c0_ui1)
firrtl.strictconnect %mem_rwPort_readData_rw_wmode, %18 : !firrtl.uint<1>
}

// CHECK: firrtl.module @InferUnmasked
firrtl.module @InferUnmasked(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>) attributes {convention = #firrtl<convention scalarized>} {
%readwritePortA_isWrite_2 = firrtl.wire {name = "readwritePortA_isWrite"} : !firrtl.uint<1>
%syncreadmem_singleport_readwritePortA_readData_rw = firrtl.mem Undefined {depth = 64 : i64, name = "syncreadmem_singleport", portNames = ["readwritePortA_readData_rw"], readLatency = 1 : i32, writeLatency = 1 : i32} : !firrtl.bundle<addr: uint<6>, en: uint<1>, clk: clock, rdata flip: uint<10>, wmode: uint<1>, wdata: uint<10>, wmask: uint<5>>
// CHECK: %syncreadmem_singleport_readwritePortA_readData_rw = firrtl.mem Undefined {depth = 64 : i64, name = "syncreadmem_singleport", portNames = ["readwritePortA_readData_rw"], readLatency = 1 : i32, writeLatency = 1 : i32} : !firrtl.bundle<addr: uint<6>, en: uint<1>, clk: clock, rdata flip: uint<10>, wmode: uint<1>, wdata: uint<10>, wmask: uint<1>>
%syncreadmem_singleport_readwritePortA_readData_rw_wmask_x = firrtl.wire : !firrtl.uint<1>
%syncreadmem_singleport_readwritePortA_readData_rw_wmask_y = firrtl.wire : !firrtl.uint<1>
%9 = firrtl.subfield %syncreadmem_singleport_readwritePortA_readData_rw[wmask] : !firrtl.bundle<addr: uint<6>, en: uint<1>, clk: clock, rdata flip: uint<10>, wmode: uint<1>, wdata: uint<10>, wmask: uint<5>>
%10 = firrtl.cat %syncreadmem_singleport_readwritePortA_readData_rw_wmask_y, %syncreadmem_singleport_readwritePortA_readData_rw_wmask_x : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<2>
%11 = firrtl.bits %10 0 to 0 : (!firrtl.uint<2>) -> !firrtl.uint<1>
%12 = firrtl.cat %11, %11 : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<2>
%13 = firrtl.cat %11, %12 : (!firrtl.uint<1>, !firrtl.uint<2>) -> !firrtl.uint<3>
%14 = firrtl.bits %10 1 to 1 : (!firrtl.uint<2>) -> !firrtl.uint<1>
%15 = firrtl.cat %14, %13 : (!firrtl.uint<1>, !firrtl.uint<3>) -> !firrtl.uint<4>
%16 = firrtl.cat %14, %15 : (!firrtl.uint<1>, !firrtl.uint<4>) -> !firrtl.uint<5>
firrtl.strictconnect %9, %16 : !firrtl.uint<5>
firrtl.connect %syncreadmem_singleport_readwritePortA_readData_rw_wmask_x, %readwritePortA_isWrite_2 : !firrtl.uint<1>, !firrtl.uint<1>
firrtl.connect %syncreadmem_singleport_readwritePortA_readData_rw_wmask_y, %readwritePortA_isWrite_2 : !firrtl.uint<1>, !firrtl.uint<1>
}
}

Loading