From 11e5e7dd6453c23c01a86f591f6bf10043ca1f0c Mon Sep 17 00:00:00 2001 From: James Molloy Date: Mon, 18 Nov 2024 01:56:22 -0800 Subject: [PATCH] [xls][mlir] Fix bug in array-to-bits, add support for ArrayConcatOp and speed up We weren't declaring ArrayIndexOp and friends illegal, so the rewriter sometimes would not apply our rewrite patterns. Also noticed that ArrayConcatOp was missing. Also noticed it was super slow on a large module due to no parallelization, so added that. Then noticed that XLS verification was super slow (dominant in multithreading as we verify per thread), so fixed that. PiperOrigin-RevId: 697543700 --- xls/contrib/mlir/IR/xls_ops.cc | 19 +++++++----- xls/contrib/mlir/IR/xls_ops.td | 12 ++++---- xls/contrib/mlir/testdata/array_to_bits.mlir | 11 +++++++ xls/contrib/mlir/transforms/array_to_bits.cc | 32 ++++++++++++++++++-- 4 files changed, 58 insertions(+), 16 deletions(-) diff --git a/xls/contrib/mlir/IR/xls_ops.cc b/xls/contrib/mlir/IR/xls_ops.cc index ddc2c49952..90f6bec7c6 100644 --- a/xls/contrib/mlir/IR/xls_ops.cc +++ b/xls/contrib/mlir/IR/xls_ops.cc @@ -583,8 +583,9 @@ LogicalResult SpawnOp::verifySymbolUses(SymbolTableCollection& symbolTable) { namespace { LogicalResult verifyChannelUsingOp(Operation* op, SymbolRefAttr channelAttr, - Type elementType) { - auto chanOp = SymbolTable::lookupNearestSymbolFrom(op, channelAttr); + Type elementType, + SymbolTableCollection& symbolTable) { + auto chanOp = symbolTable.lookupNearestSymbolFrom(op, channelAttr); if (!chanOp) { return op->emitOpError("channel symbol not found: ") << channelAttr; } @@ -607,19 +608,21 @@ LogicalResult verifyStructuredChannelUsingOp(Operation* op, Value channel, } // namespace -LogicalResult BlockingReceiveOp::verify() { +LogicalResult BlockingReceiveOp::verifySymbolUses( + SymbolTableCollection& symbolTable) { return verifyChannelUsingOp(getOperation(), getChannelAttr(), - getResult().getType()); + getResult().getType(), symbolTable); } -LogicalResult NonblockingReceiveOp::verify() { +LogicalResult NonblockingReceiveOp::verifySymbolUses( + SymbolTableCollection& symbolTable) { return verifyChannelUsingOp(getOperation(), getChannelAttr(), - getResult().getType()); + getResult().getType(), symbolTable); } -LogicalResult SendOp::verify() { +LogicalResult SendOp::verifySymbolUses(SymbolTableCollection& symbolTable) { return verifyChannelUsingOp(getOperation(), getChannelAttr(), - getData().getType()); + getData().getType(), symbolTable); } LogicalResult SBlockingReceiveOp::verify() { diff --git a/xls/contrib/mlir/IR/xls_ops.td b/xls/contrib/mlir/IR/xls_ops.td index 9dbb4187f7..ed481b35c8 100644 --- a/xls/contrib/mlir/IR/xls_ops.td +++ b/xls/contrib/mlir/IR/xls_ops.td @@ -418,7 +418,8 @@ def Xls_SignExtOp : Xls_UnaryOp<"sign_ext", [Pure, SameOperandsAndResultShape]> def Xls_BlockingReceiveOp : Xls_Op<"blocking_receive", [ TensorArrayTypeFungible, PredicatableOpInterface, - CallOpInterface]> { + CallOpInterface, + DeclareOpInterfaceMethods]> { let summary = "Receives a data value from a specified channel."; let description = [{ Receives a data value from a specified channel. The type of the data value @@ -440,7 +441,6 @@ def Xls_BlockingReceiveOp : Xls_Op<"blocking_receive", [ let assemblyFormat = [{ $tkn `,` ($predicate^ `,`)? $channel attr-dict `:` type($result) }]; - let hasVerifier = 1; let extraClassDeclaration = [{ ::mlir::Value getCondition() { return getPredicate(); @@ -464,7 +464,8 @@ def Xls_BlockingReceiveOp : Xls_Op<"blocking_receive", [ def Xls_NonblockingReceiveOp : Xls_Op<"nonblocking_receive", [ TensorArrayTypeFungible, PredicatableOpInterface, - CallOpInterface]> { + CallOpInterface, + DeclareOpInterfaceMethods]> { let summary = "Receives a data value from a specified channel."; let description = [{ Receives a data value from a specified channel. The type of the data value @@ -488,7 +489,6 @@ def Xls_NonblockingReceiveOp : Xls_Op<"nonblocking_receive", [ let assemblyFormat = [{ $tkn `,` ($predicate^ `,`)? $channel attr-dict `:` type($result) }]; - let hasVerifier = 1; let extraClassDeclaration = [{ ::mlir::Value getCondition() { return getPredicate(); @@ -514,7 +514,8 @@ def Xls_NonblockingReceiveOp : Xls_Op<"nonblocking_receive", [ def Xls_SendOp : Xls_Op<"send", [ TensorArrayTypeFungible, PredicatableOpInterface, - CallOpInterface]> { + CallOpInterface, + DeclareOpInterfaceMethods]> { let summary = "Sends data to a specified channel."; let description = [{ Sends data to a specified channel. The type of the data values is determined @@ -535,7 +536,6 @@ def Xls_SendOp : Xls_Op<"send", [ let assemblyFormat = [{ $tkn `,` $data `,` ($predicate^ `,`)? $channel attr-dict `:` type($data) }]; - let hasVerifier = 1; let extraClassDeclaration = [{ ::mlir::Value getCondition() { return getPredicate(); diff --git a/xls/contrib/mlir/testdata/array_to_bits.mlir b/xls/contrib/mlir/testdata/array_to_bits.mlir index 81c30cf329..c89d3f9e95 100644 --- a/xls/contrib/mlir/testdata/array_to_bits.mlir +++ b/xls/contrib/mlir/testdata/array_to_bits.mlir @@ -263,3 +263,14 @@ func.func @call_dslx(%arg0: !xls.array<4 x i32>) -> !xls.array<4 x f32> attribut %8 = xls.array %1, %3, %5, %7 : (f32, f32, f32, f32) -> !xls.array<4 x f32> return %8 : !xls.array<4 x f32> } + +// CHECK-LABEL: func.func @array_concat( +// CHECK-SAME: %[[VAL_0:.*]]: i64, +// CHECK-SAME: %[[VAL_1:.*]]: i64) -> i128 attributes {xls = true} { +// CHECK: %[[VAL_2:.*]] = xls.concat %[[VAL_0]], %[[VAL_1]] : (i64, i64) -> i128 +// CHECK: return %[[VAL_2]] : i128 +// CHECK: } +func.func @array_concat(%arg0: !xls.array<2 x i32>, %arg1: !xls.array<2 x i32>) -> !xls.array<4 x i32> attributes {xls = true} { + %0 = "xls.array_concat"(%arg0, %arg1) : (!xls.array<2 x i32>, !xls.array<2 x i32>) -> !xls.array<4 x i32> + return %0 : !xls.array<4 x i32> +} diff --git a/xls/contrib/mlir/transforms/array_to_bits.cc b/xls/contrib/mlir/transforms/array_to_bits.cc index 0aee7959aa..44217dabf4 100644 --- a/xls/contrib/mlir/transforms/array_to_bits.cc +++ b/xls/contrib/mlir/transforms/array_to_bits.cc @@ -28,6 +28,7 @@ #include "mlir/include/mlir/IR/OpDefinition.h" #include "mlir/include/mlir/IR/OperationSupport.h" #include "mlir/include/mlir/IR/PatternMatch.h" +#include "mlir/include/mlir/IR/Threading.h" #include "mlir/include/mlir/IR/TypeUtilities.h" #include "mlir/include/mlir/IR/ValueRange.h" #include "mlir/include/mlir/IR/Visitors.h" @@ -296,6 +297,24 @@ class LegalizeArrayZeroPattern : public OpConversionPattern { } }; +class LegalizeArrayConcatPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + ArrayConcatOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + (void)adaptor; + SmallVector operands = + CoerceFloats(adaptor.getOperands(), rewriter, op); + if (operands.empty() && !adaptor.getOperands().empty()) { + return failure(); + } + rewriter.replaceOpWithNewOp( + op, typeConverter->convertType(op.getType()), operands); + return success(); + } +}; + class ArrayToBitsPass : public impl::ArrayToBitsPassBase { public: void runOnOperation() override { @@ -308,7 +327,9 @@ class ArrayToBitsPass : public impl::ArrayToBitsPassBase { return all_of(op->getOperandTypes(), is_legal) && all_of(op->getResultTypes(), is_legal); }); - target.addIllegalOp(); + target.addIllegalOp(); RewritePatternSet chanPatterns(&getContext()); chanPatterns.add(typeConverter, &getContext()); FrozenRewritePatternSet frozenChanPatterns(std::move(chanPatterns)); @@ -323,6 +344,7 @@ class ArrayToBitsPass : public impl::ArrayToBitsPassBase { LegalizeArrayIndexPattern, LegalizeArrayIndexStaticPattern, LegalizeArrayZeroPattern, + LegalizeArrayConcatPattern, LegalizeGenericOpPattern // clang-format on >(typeConverter, &getContext()); @@ -338,10 +360,11 @@ class ArrayToBitsPass : public impl::ArrayToBitsPassBase { }); FrozenRewritePatternSet frozenRegionPatterns(std::move(regionPatterns)); + SmallVector regions; getOperation()->walk([&](Operation* op) { if (auto interface = dyn_cast(op)) { if (interface.isSupportedRegion()) { - runOnOperation(interface, target, frozenRegionPatterns); + regions.push_back(interface); return WalkResult::skip(); } } else if (auto chanOp = dyn_cast(op)) { @@ -350,6 +373,11 @@ class ArrayToBitsPass : public impl::ArrayToBitsPassBase { } return WalkResult::advance(); }); + + mlir::parallelForEach( + &getContext(), regions, [&](XlsRegionOpInterface interface) { + runOnOperation(interface, target, frozenRegionPatterns); + }); } void runOnOperation(ChanOp operation, ConversionTarget& target,