Skip to content

Commit

Permalink
[xls][mlir] Fix bug in array-to-bits, add support for ArrayConcatOp a…
Browse files Browse the repository at this point in the history
…nd speed up

We weren't declaring ArrayIndexOp and friends illegal, so the rewriter sometimes would not apply our rewrite patterns.

Also noticed that ArrayConcatOp was missing.

Also noticed it was super slow on a large module due to no parallelization, so added that.

Then noticed that XLS verification was super slow (dominant in multithreading as we verify per thread),
so fixed that.

PiperOrigin-RevId: 697543700
  • Loading branch information
James Molloy authored and copybara-github committed Nov 18, 2024
1 parent 64e6d11 commit 11e5e7d
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 16 deletions.
19 changes: 11 additions & 8 deletions xls/contrib/mlir/IR/xls_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -583,8 +583,9 @@ LogicalResult SpawnOp::verifySymbolUses(SymbolTableCollection& symbolTable) {

namespace {
LogicalResult verifyChannelUsingOp(Operation* op, SymbolRefAttr channelAttr,
Type elementType) {
auto chanOp = SymbolTable::lookupNearestSymbolFrom<ChanOp>(op, channelAttr);
Type elementType,
SymbolTableCollection& symbolTable) {
auto chanOp = symbolTable.lookupNearestSymbolFrom<ChanOp>(op, channelAttr);
if (!chanOp) {
return op->emitOpError("channel symbol not found: ") << channelAttr;
}
Expand All @@ -607,19 +608,21 @@ LogicalResult verifyStructuredChannelUsingOp(Operation* op, Value channel,

} // namespace

LogicalResult BlockingReceiveOp::verify() {
LogicalResult BlockingReceiveOp::verifySymbolUses(
SymbolTableCollection& symbolTable) {
return verifyChannelUsingOp(getOperation(), getChannelAttr(),
getResult().getType());
getResult().getType(), symbolTable);
}

LogicalResult NonblockingReceiveOp::verify() {
LogicalResult NonblockingReceiveOp::verifySymbolUses(
SymbolTableCollection& symbolTable) {
return verifyChannelUsingOp(getOperation(), getChannelAttr(),
getResult().getType());
getResult().getType(), symbolTable);
}

LogicalResult SendOp::verify() {
LogicalResult SendOp::verifySymbolUses(SymbolTableCollection& symbolTable) {
return verifyChannelUsingOp(getOperation(), getChannelAttr(),
getData().getType());
getData().getType(), symbolTable);
}

LogicalResult SBlockingReceiveOp::verify() {
Expand Down
12 changes: 6 additions & 6 deletions xls/contrib/mlir/IR/xls_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,8 @@ def Xls_SignExtOp : Xls_UnaryOp<"sign_ext", [Pure, SameOperandsAndResultShape]>
def Xls_BlockingReceiveOp : Xls_Op<"blocking_receive", [
TensorArrayTypeFungible,
PredicatableOpInterface,
CallOpInterface]> {
CallOpInterface,
DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "Receives a data value from a specified channel.";
let description = [{
Receives a data value from a specified channel. The type of the data value
Expand All @@ -440,7 +441,6 @@ def Xls_BlockingReceiveOp : Xls_Op<"blocking_receive", [
let assemblyFormat = [{
$tkn `,` ($predicate^ `,`)? $channel attr-dict `:` type($result)
}];
let hasVerifier = 1;
let extraClassDeclaration = [{
::mlir::Value getCondition() {
return getPredicate();
Expand All @@ -464,7 +464,8 @@ def Xls_BlockingReceiveOp : Xls_Op<"blocking_receive", [
def Xls_NonblockingReceiveOp : Xls_Op<"nonblocking_receive", [
TensorArrayTypeFungible,
PredicatableOpInterface,
CallOpInterface]> {
CallOpInterface,
DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "Receives a data value from a specified channel.";
let description = [{
Receives a data value from a specified channel. The type of the data value
Expand All @@ -488,7 +489,6 @@ def Xls_NonblockingReceiveOp : Xls_Op<"nonblocking_receive", [
let assemblyFormat = [{
$tkn `,` ($predicate^ `,`)? $channel attr-dict `:` type($result)
}];
let hasVerifier = 1;
let extraClassDeclaration = [{
::mlir::Value getCondition() {
return getPredicate();
Expand All @@ -514,7 +514,8 @@ def Xls_NonblockingReceiveOp : Xls_Op<"nonblocking_receive", [
def Xls_SendOp : Xls_Op<"send", [
TensorArrayTypeFungible,
PredicatableOpInterface,
CallOpInterface]> {
CallOpInterface,
DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "Sends data to a specified channel.";
let description = [{
Sends data to a specified channel. The type of the data values is determined
Expand All @@ -535,7 +536,6 @@ def Xls_SendOp : Xls_Op<"send", [
let assemblyFormat = [{
$tkn `,` $data `,` ($predicate^ `,`)? $channel attr-dict `:` type($data)
}];
let hasVerifier = 1;
let extraClassDeclaration = [{
::mlir::Value getCondition() {
return getPredicate();
Expand Down
11 changes: 11 additions & 0 deletions xls/contrib/mlir/testdata/array_to_bits.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,14 @@ func.func @call_dslx(%arg0: !xls.array<4 x i32>) -> !xls.array<4 x f32> attribut
%8 = xls.array %1, %3, %5, %7 : (f32, f32, f32, f32) -> !xls.array<4 x f32>
return %8 : !xls.array<4 x f32>
}

// CHECK-LABEL: func.func @array_concat(
// CHECK-SAME: %[[VAL_0:.*]]: i64,
// CHECK-SAME: %[[VAL_1:.*]]: i64) -> i128 attributes {xls = true} {
// CHECK: %[[VAL_2:.*]] = xls.concat %[[VAL_0]], %[[VAL_1]] : (i64, i64) -> i128
// CHECK: return %[[VAL_2]] : i128
// CHECK: }
func.func @array_concat(%arg0: !xls.array<2 x i32>, %arg1: !xls.array<2 x i32>) -> !xls.array<4 x i32> attributes {xls = true} {
%0 = "xls.array_concat"(%arg0, %arg1) : (!xls.array<2 x i32>, !xls.array<2 x i32>) -> !xls.array<4 x i32>
return %0 : !xls.array<4 x i32>
}
32 changes: 30 additions & 2 deletions xls/contrib/mlir/transforms/array_to_bits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "mlir/include/mlir/IR/OpDefinition.h"
#include "mlir/include/mlir/IR/OperationSupport.h"
#include "mlir/include/mlir/IR/PatternMatch.h"
#include "mlir/include/mlir/IR/Threading.h"
#include "mlir/include/mlir/IR/TypeUtilities.h"
#include "mlir/include/mlir/IR/ValueRange.h"
#include "mlir/include/mlir/IR/Visitors.h"
Expand Down Expand Up @@ -296,6 +297,24 @@ class LegalizeArrayZeroPattern : public OpConversionPattern<ArrayZeroOp> {
}
};

class LegalizeArrayConcatPattern : public OpConversionPattern<ArrayConcatOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
ArrayConcatOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
(void)adaptor;
SmallVector<Value> operands =
CoerceFloats(adaptor.getOperands(), rewriter, op);
if (operands.empty() && !adaptor.getOperands().empty()) {
return failure();
}
rewriter.replaceOpWithNewOp<ConcatOp>(
op, typeConverter->convertType(op.getType()), operands);
return success();
}
};

class ArrayToBitsPass : public impl::ArrayToBitsPassBase<ArrayToBitsPass> {
public:
void runOnOperation() override {
Expand All @@ -308,7 +327,9 @@ class ArrayToBitsPass : public impl::ArrayToBitsPassBase<ArrayToBitsPass> {
return all_of(op->getOperandTypes(), is_legal) &&
all_of(op->getResultTypes(), is_legal);
});
target.addIllegalOp<VectorizedCallOp>();
target.addIllegalOp<VectorizedCallOp, ArrayOp, ArrayUpdateOp, ArraySliceOp,
ArrayIndexOp, ArrayIndexStaticOp, ArrayZeroOp,
ArrayConcatOp>();
RewritePatternSet chanPatterns(&getContext());
chanPatterns.add<LegalizeChanOpPattern>(typeConverter, &getContext());
FrozenRewritePatternSet frozenChanPatterns(std::move(chanPatterns));
Expand All @@ -323,6 +344,7 @@ class ArrayToBitsPass : public impl::ArrayToBitsPassBase<ArrayToBitsPass> {
LegalizeArrayIndexPattern,
LegalizeArrayIndexStaticPattern,
LegalizeArrayZeroPattern,
LegalizeArrayConcatPattern,
LegalizeGenericOpPattern
// clang-format on
>(typeConverter, &getContext());
Expand All @@ -338,10 +360,11 @@ class ArrayToBitsPass : public impl::ArrayToBitsPassBase<ArrayToBitsPass> {
});
FrozenRewritePatternSet frozenRegionPatterns(std::move(regionPatterns));

SmallVector<XlsRegionOpInterface> regions;
getOperation()->walk([&](Operation* op) {
if (auto interface = dyn_cast<XlsRegionOpInterface>(op)) {
if (interface.isSupportedRegion()) {
runOnOperation(interface, target, frozenRegionPatterns);
regions.push_back(interface);
return WalkResult::skip();
}
} else if (auto chanOp = dyn_cast<ChanOp>(op)) {
Expand All @@ -350,6 +373,11 @@ class ArrayToBitsPass : public impl::ArrayToBitsPassBase<ArrayToBitsPass> {
}
return WalkResult::advance();
});

mlir::parallelForEach(
&getContext(), regions, [&](XlsRegionOpInterface interface) {
runOnOperation(interface, target, frozenRegionPatterns);
});
}

void runOnOperation(ChanOp operation, ConversionTarget& target,
Expand Down

0 comments on commit 11e5e7d

Please sign in to comment.