Skip to content

Commit

Permalink
Add new general convolution op and layout attribute, convert stablehlo convolution to it (#1104)
Browse files Browse the repository at this point in the history

Add pass to transform eligible ConvolutionOp --> Conv2dOp in TTIR

Introduce reshapes in ttir->ttnn conversion of conv2d

Use output memory config specification feature recently added to ttnn::conv2d

Create dedicated silicon tests
  • Loading branch information
LPanosTT authored Oct 31, 2024
1 parent 677be1d commit c4dd537
Show file tree
Hide file tree
Showing 17 changed files with 792 additions and 48 deletions.
6 changes: 6 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@ add_mlir_dialect(TTIROps ttir)
add_mlir_doc(TTIRBase TTIRDialect src/autogen/md/Dialect/ -gen-dialect-doc)
add_mlir_doc(TTIROps TTIROp src/autogen/md/Dialect/ -gen-op-doc)

# Generate attribute-definition declarations/definitions (TTIROpsAttrs.h.inc /
# TTIROpsAttrs.cpp.inc) from TTIROpsAttrs.td via mlir-tablegen.
set(LLVM_TARGET_DEFINITIONS TTIROpsAttrs.td)
mlir_tablegen(TTIROpsAttrs.h.inc -gen-attrdef-decls)
mlir_tablegen(TTIROpsAttrs.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(TTIROpsAttrsIncGen)
# Make the aggregate mlir-headers target wait for the generated attr headers,
# so any consumer of mlir-headers can safely include TTIROpsAttrs.h.inc.
add_dependencies(mlir-headers TTIROpsAttrsIncGen)

set(LLVM_TARGET_DEFINITIONS TTIROpsInterfaces.td)
mlir_tablegen(TTIROpsInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(TTIROpsInterfaces.cpp.inc -gen-op-interface-defs)
Expand Down
2 changes: 1 addition & 1 deletion include/ttmlir/Dialect/TTIR/IR/TTIRBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def TTIR_Dialect : Dialect {
or dialects that are actually supported by a consuming backend.
}];
let cppNamespace = "::mlir::tt::ttir";

let useDefaultAttributePrinterParser = 1;
let dependentDialects = [
"::mlir::arith::ArithDialect",
"::mlir::func::FuncDialect",
Expand Down
3 changes: 3 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

#include "TTIROpsInterfaces.h"

#define GET_ATTRDEF_CLASSES
#include "ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.h.inc"

#define GET_OP_CLASSES
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h.inc"

Expand Down
40 changes: 40 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

include "ttmlir/Dialect/TT/IR/TTOpsTypes.td"
include "ttmlir/Dialect/TTIR/IR/TTIRBase.td"
include "ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.td"
include "ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td"
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
Expand Down Expand Up @@ -598,6 +599,45 @@ def TTIR_Conv2dOp : TTIR_DPSOp<"conv2d"> {
let hasVerifier = 1;
}

// Generalized convolution: arbitrary spatial rank, strides, padding,
// input/weight dilation, window reversal, and grouping. Dimension roles are
// carried by the convolution_layout attribute rather than a fixed layout.
def TTIR_ConvolutionOp : TTIR_DPSOp<"convolution"> {
let summary = "Generalized convolution op.";
let description = [{
Applies a convolution of the rhs with the lhs.

This operation captures convolutions of all dimensionality as well
as deconvolution/conv transpose.
}];

let arguments = (ins
// Activation tensor; dimension roles given by $convolution_layout.
AnyRankedTensor:$input,
// Kernel tensor; dimension roles given by $convolution_layout.
AnyRankedTensor:$weight,
// Optional bias operand.
Optional<AnyRankedTensor>:$bias,
// DPS init operand (destination of the result).
AnyRankedTensor:$output,
// Window stride per spatial dimension.
// Default value: one for each spatial dimension.
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "SmallVector<int64_t>(getConvolutionLayout().getInputSpatialDimensions().size(), 1)">:$window_strides,
// Low/high padding per spatial dimension (two entries per dimension).
// Default value: two zeros for each spatial dimension.
// NOTE(review): padding is an I64ElementsAttr while the sibling window
// attributes use DenseI64ArrayAttr — presumably to mirror stablehlo's
// padding representation; confirm the asymmetry is intentional.
DefaultValuedOptionalAttr<I64ElementsAttr, "SmallVector<int64_t>(getConvolutionLayout().getInputSpatialDimensions().size()*2, 0)">:$padding,
// Dilation of the input (lhs); used to express conv transpose.
// Default value: one for each spatial dimension.
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "SmallVector<int64_t>(getConvolutionLayout().getInputSpatialDimensions().size(), 1)">:$input_dilation,
// Dilation of the weight (rhs), i.e. atrous/dilated convolution.
// Default value: one for each spatial dimension.
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "SmallVector<int64_t>(getConvolutionLayout().getInputSpatialDimensions().size(), 1)">:$weight_dilation,
// Whether the kernel window is reversed along each spatial dimension.
// Default value: false for each spatial dimension.
DefaultValuedOptionalAttr<DenseBoolArrayAttr, "SmallVector<bool>(getConvolutionLayout().getInputSpatialDimensions().size(), false)">:$window_reversal,
// Maps batch/feature/spatial dimension indices for input, kernel, output.
TTIR_ConvolutionLayoutAttr:$convolution_layout,
// Grouped convolution factor; constrained to be strictly positive.
ConfinedAttr<I64Attr, [IntPositive]>:$feature_group_count,
// Batch grouping factor; constrained to be strictly positive.
ConfinedAttr<I64Attr, [IntPositive]>:$batch_group_count,
TT_OperandConstraintArrayAttr:$operand_constraints
);

let results = (outs AnyRankedTensor);
let hasVerifier = 1;

let extraClassDeclaration = [{
// DestinationStyle interface: $output is the sole DPS init operand.
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];
}


def TTIR_MaxPool2dOp : TTIR_DPSOp<"max_pool2d"> {
let summary = "Applies a 2D max pooling over an input signal composed of several input planes.";
let description = [{
Expand Down
44 changes: 44 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_TTIR_ATTRS_TD
#define TTMLIR_TTIR_ATTRS_TD

include "mlir/IR/AttrTypeBase.td"
include "ttmlir/Dialect/TTIR/IR/TTIRBase.td"

// Dimension-numbering attribute for TTIR_ConvolutionOp: records which tensor
// dimension plays the batch / feature / spatial role for the input
// activation, the kernel (weights), and the output.
def TTIR_ConvolutionLayoutAttr : AttrDef<TTIR_Dialect, "ConvolutionLayout", [], "::mlir::Attribute"> {
let mnemonic = "convolution_layout";
let summary = "Structure of dimension information for convolution op";
let description = [{
Holds the layout information for the input activation, weights, and output.
}];
let parameters = (ins
// Input (activation) tensor dimension roles.
"int64_t":$inputBatchDimension,
"int64_t":$inputFeatureDimension,
ArrayRefParameter<"int64_t">:$inputSpatialDimensions,

// Kernel (weight) tensor dimension roles.
"int64_t":$kernelOutputFeatureDimension,
"int64_t":$kernelInputFeatureDimension,
ArrayRefParameter<"int64_t">:$kernelSpatialDimensions,

// Output tensor dimension roles.
"int64_t":$outputBatchDimension,
"int64_t":$outputFeatureDimension,
ArrayRefParameter<"int64_t">:$outputSpatialDimensions
);

// Keyword-labelled custom syntax for each role. Spatial dimension lists use
// the custom<DimensionList> directive.
// NOTE(review): custom<DimensionList> requires parseDimensionList /
// printDimensionList hooks defined elsewhere in the dialect — not visible in
// this view; confirm they exist.
let assemblyFormat = [{
`input_batch` `=` $inputBatchDimension `,`
`input_feature` `=` $inputFeatureDimension`,`
`input_spatial_dimensions` `=` custom<DimensionList>($inputSpatialDimensions) `,`
`kernel_output_feature` `=` $kernelOutputFeatureDimension `,`
`kernel_input_feature` `=` $kernelInputFeatureDimension `,`
`kernel_spatial_dimensions` `=` custom<DimensionList>($kernelSpatialDimensions) `,`
`output_batch` `=` $outputBatchDimension `,`
`output_feature` `=` $outputFeatureDimension `,`
`output_spatial_dimensions` `=` custom<DimensionList>($outputSpatialDimensions)
}];
}

#endif // TTMLIR_TTIR_ATTRS_TD
52 changes: 17 additions & 35 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,46 +373,28 @@ class StableHLOToTTIRConvolutionOpConversionPattern
tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());

std::vector<int64_t> strides =
adaptor.getWindowStrides().value_or(ArrayRef<int64_t>({1, 1})).vec();
IntegerAttr stride_height_attr =
rewriter.getSI32IntegerAttr(static_cast<int32_t>(strides[0]));
IntegerAttr stride_width_attr =
rewriter.getSI32IntegerAttr(static_cast<int32_t>(strides[1]));

std::vector<int64_t> dilation =
adaptor.getLhsDilation().value_or(ArrayRef<int64_t>({1, 1})).vec();

IntegerAttr dilation_height_attr =
rewriter.getSI32IntegerAttr(static_cast<int32_t>(dilation[0]));
IntegerAttr dilation_width_attr =
rewriter.getSI32IntegerAttr(static_cast<int32_t>(dilation[1]));

IntegerAttr groups_attr = rewriter.getSI32IntegerAttr(
static_cast<int32_t>(adaptor.getFeatureGroupCount()));

std::vector<int32_t> padding;
if (!adaptor.getPadding().has_value()) {
padding = {0, 0, 0, 0};
} else {
for (auto iter = adaptor.getPadding()->value_begin<int64_t>();
iter < adaptor.getPadding()->value_end<int64_t>(); iter++) {
padding.push_back(static_cast<int32_t>(*iter));
}
}

rewriter.replaceOpWithNewOp<mlir::tt::ttir::Conv2dOp>(
auto dimNums = adaptor.getDimensionNumbers();
rewriter.replaceOpWithNewOp<mlir::tt::ttir::ConvolutionOp>(
srcOp, outputType, adaptor.getLhs(), adaptor.getRhs(),
mlir::Value(nullptr), outputTensor, stride_height_attr,
stride_width_attr, dilation_height_attr, dilation_width_attr,
groups_attr, rewriter.getSI32IntegerAttr(padding[0]),
rewriter.getSI32IntegerAttr(padding[1]),
rewriter.getSI32IntegerAttr(padding[2]),
rewriter.getSI32IntegerAttr(padding[3]),
mlir::Value(nullptr), outputTensor, adaptor.getWindowStridesAttr(),
adaptor.getPaddingAttr(), adaptor.getLhsDilationAttr(),
adaptor.getRhsDilationAttr(), adaptor.getWindowReversalAttr(),
mlir::tt::ttir::ConvolutionLayoutAttr::get(
getContext(), dimNums.getInputBatchDimension(),
dimNums.getInputFeatureDimension(),
dimNums.getInputSpatialDimensions(),
dimNums.getKernelOutputFeatureDimension(),
dimNums.getKernelInputFeatureDimension(),
dimNums.getKernelSpatialDimensions(),
dimNums.getOutputBatchDimension(),
dimNums.getOutputFeatureDimension(),
dimNums.getOutputSpatialDimensions()),
adaptor.getFeatureGroupCountAttr(), adaptor.getBatchGroupCountAttr(),
rewriter.getArrayAttr(
SmallVector<Attribute>(adaptor.getOperands().size() + 1,
rewriter.getAttr<OperandConstraintAttr>(
OperandConstraint::AnyDeviceTile))));

return success();
}
};
Expand Down
Loading

0 comments on commit c4dd537

Please sign in to comment.