diff --git a/docs/src/dialects-overview.md b/docs/src/dialects-overview.md index e886fb90c1..0dbf5fbed1 100644 --- a/docs/src/dialects-overview.md +++ b/docs/src/dialects-overview.md @@ -3,7 +3,7 @@ Here is a brief overview of the dialects in the project, please refer to the individual dialect documentation for more details.: -- `tt`: Common types such as, `tt.tile`, `tt.layout`, `tt.grid`, etc. and enums such as, data formats, memory spaces, iterator types etc. +- `tt`: Common types such as, `tt.tile`, `tt.metal_layout`, `tt.grid`, etc. and enums such as, data formats, memory spaces, iterator types etc. - `ttir`: A high level dialect that models the tensor compute graph on tenstorrent devices. Accepts `tosa` and `linalg` input. - `ttir.generic`: Generically describe compute work. - `ttir.to_layout`: Convert between different tensor memory layouts and transfer between different memory spaces. diff --git a/docs/src/specs/device.md b/docs/src/specs/device.md index ae72fe638c..64bc91cfa9 100644 --- a/docs/src/specs/device.md +++ b/docs/src/specs/device.md @@ -135,7 +135,7 @@ the logical device grid: ```mlir tensor<16x3x64x128xf32, - #tt.layout<(d0, d1, d2, d3) -> (d0, d1 * 64 + d2, d3), + #tt.metal_layout<(d0, d1, d2, d3) -> (d0, d1 * 64 + d2, d3), undef, <2x2x4>, memref<8x3x1x!tt.tile<32 x 32, bfp_bf8>, #tt.memory_space> @@ -170,7 +170,7 @@ the logical device grid: ```mlir tensor<256x1024xf32, - #tt.layout<(d0, d1) -> (d0, d1), + #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <4x16>, memref<2x2x!tt.tile<32 x 32, bfp_bf8>, #tt.memory_space> @@ -205,7 +205,7 @@ We can consider the following tensor to map onto this grid: ```mlir tensor<64x256x1024xf32, - #tt.layout<(d0, d1) -> (d0, d1), + #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <2x4x16>, memref<32x2x2x!tt.tile<32 x 32, bfp_bf8>, #tt.memory_space> diff --git a/docs/src/specs/tensor-layout.md b/docs/src/specs/tensor-layout.md index d523f51ed2..52c6931895 100644 --- a/docs/src/specs/tensor-layout.md +++ b/docs/src/specs/tensor-layout.md @@ -33,7 +33,7 @@ been used by the TT dialect to encode the tensor's layout. This looks like: ```mlir tensor<2x3x64x128xf32, - #tt.layout< + #tt.metal_layout< (d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <1x1>, @@ -76,7 +76,7 @@ topics: ### Dimension Collapsing -Probably the most important concept in `tt.layout` is dimension collapsing. +Probably the most important concept in `tt.metal_layout` is dimension collapsing. This is captured by the affine map `linear` property which provides a mapping from tensor dim space to a reduced physical dimensional space. This single-handedly touches on most of the tensor layout goals mentioned at the @@ -106,7 +106,7 @@ to get our remapped offset: This remapped offset `(262, 100)` corresponds to the row and column index of the collapsed physical memory. -By default, the dim range `[0, -1)` is collapsed, but the `tt.layout` contructor +By default, the dim range `[0, -1)` is collapsed, but the `tt.metal_layout` contructor can actually take a programmable range called `collapseIntervals`. `collapseIntervals` is a list of pairs, where each pair is a dim range interval, left inclusive, right exclusive. Let's consider a few examples: @@ -137,7 +137,7 @@ Let's consider the original example again, but on a larger grid than `1x1`, say ```mlir tensor<2x3x64x128xf32, - #tt.layout< + #tt.metal_layout< (d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <2x4>, @@ -173,7 +173,7 @@ Here's a few more example mlir snippets: ```mlir tensor<8x300xf32, - #tt.layout<(d0, d1) -> (d0, d1), + #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x2>, memref<8x150xf32, #tt.memory_space> @@ -181,7 +181,7 @@ tensor<8x300xf32, > tensor<8x96x32xf32, - #tt.layout<(d0, d1, d2) -> (d0 * 96 + d1, d2), + #tt.metal_layout<(d0, d1, d2) -> (d0 * 96 + d1, d2), undef, <2x1>, memref<384x32xf32, #tt.memory_space> @@ -189,7 +189,7 @@ tensor<8x96x32xf32, > tensor<8x96x32xf32, - #tt.layout<(d0, d1, d2) -> (d0 * 96 + d1, d1, d2), + #tt.metal_layout<(d0, d1, d2) -> (d0 * 96 + d1, d1, d2), undef, <2x1x2>, memref<384x96x16xf32, #tt.memory_space> @@ -197,7 +197,7 @@ tensor<8x96x32xf32, > tensor<5x3x2x2x7x32x32xf32, - #tt.layout< + #tt.metal_layout< (d0, d1, d2, d3, d4, d5, d6) -> (d0 * 2688 + d1 * 896 + d2 * 448 + d3 * 224 + d4 * 32 + d5, d4, d5, d6), undef, @@ -226,7 +226,7 @@ A tilized tensor is one with a memref that has a tile element type. Given some tensor with scalar layout: ```mlir tensor<3x64x128xf32, - #tt.layout< + #tt.metal_layout< (d0, d1, d2) -> (d0 * 64 + d1, d2), undef, <3x2>, @@ -238,7 +238,7 @@ tensor<3x64x128xf32, After tilizing we'll have: ```mlir tensor<3x64x128xf32, - #tt.layout< + #tt.metal_layout< (d0, d1, d2) -> (d0 * 64 + d1, d2), undef, <3x2>, @@ -256,7 +256,7 @@ intact. Padding can be a bit of an overloaded term, but in this context it refers to an out of bounds area in the physical memory allocation that has no real tensor data in it. The contents of this area is tracked by `oob_val` and the padding -area can be automatically derived from the attributes of `tt.layout`. +area can be automatically derived from the attributes of `tt.metal_layout`. Padding is a necessary evil that arises when a tensor is not evenly divisible by a grid shape or tile shape. It can also arise due to minimum Noc addressing @@ -265,7 +265,7 @@ requirements. Example of non-divisible grid: ```mlir tensor<53x63xf32, - #tt.layout< + #tt.metal_layout< (d0, d1) -> (d0, d1), undef, <3x2>, @@ -284,7 +284,7 @@ cores and 1 scalar column of padding on the last column of cores. Taking the above example a step further, we could tilize it: ```mlir tensor<53x63xf32, - #tt.layout< + #tt.metal_layout< (d0, d1) -> (d0, d1), undef, <3x2>, @@ -308,7 +308,7 @@ stride between dimensions. Consider tensor (w/ batch dim `2`): ```mlir tensor<2x8x32xf32, - #tt.layout< + #tt.metal_layout< (d0, d1, d2) -> (d0 * 8 + d1, d2), undef, <1x2>, @@ -356,7 +356,7 @@ consider the following example with a 3d grid and `collapseIntervals=[(1, -1)]`. ```mlir tensor<2x3x64x128xf32, - #tt.layout<(d0, d1, d2, d3) -> (d0, d1 * 64 + d2, d3), + #tt.metal_layout<(d0, d1, d2, d3) -> (d0, d1 * 64 + d2, d3), undef, <2x2x4>, memref<1x3x1x!tt.tile<32 x 32, bfp_bf8>, #tt.memory_space> @@ -387,7 +387,7 @@ under the same grid primitive that also divides tensor rows and columns. ## Concerns -- `tt.layout` is deliberately flexible and tries to capture as many problematic +- `tt.metal_layout` is deliberately flexible and tries to capture as many problematic use-cases we've ran into in the past in a single, succinct representation. This flexibility will need to be further constrained by backends to avoid unsupported programming of this attribute. diff --git a/include/ttmlir-c/TTAttrs.h b/include/ttmlir-c/TTAttrs.h index 2e164ac132..263cd1d8e4 100644 --- a/include/ttmlir-c/TTAttrs.h +++ b/include/ttmlir-c/TTAttrs.h @@ -50,9 +50,9 @@ MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTSystemDescAttrGet( size_t chipCoordsSize, MlirAttribute *chipChannels, size_t chipChannelsSize); -MLIR_CAPI_EXPORTED MlirAttribute -ttmlirTTLayoutAttrGet(MlirContext ctx, MlirAffineMap linear, unsigned oobVal, - MlirAttribute grid, MlirType memref, unsigned memLayout); +MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTMetalLayoutAttrGet( + MlirContext ctx, MlirAffineMap linear, unsigned oobVal, MlirAttribute grid, + MlirType memref, unsigned memLayout); MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTMemorySpaceAttrGet(MlirContext ctx, uint32_t memorySpace); diff --git a/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td b/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td index d9ff13164e..600090ccf7 100644 --- a/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td +++ b/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td @@ -214,7 +214,7 @@ def TT_SystemDescAttr : TT_Attr<"SystemDesc", "system_desc"> { }]; } -def TT_LayoutAttr : TT_Attr<"Layout", "layout"> { +def TT_MetalLayoutAttr : TT_Attr<"MetalLayout", "metal_layout"> { let summary = "Tensor layout attribute"; let description = [{ The tensor layout attribute captures how tensor data is sharded across a grid of devices, cores, and @@ -241,7 +241,7 @@ def TT_LayoutAttr : TT_Attr<"Layout", "layout"> { Examples: ```mlir tensor<8x300xf32, - #tt.layout<(d0, d1) -> (d0, d1), + #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x2>, memref<8x150xf32, #tt.memory_space> @@ -249,7 +249,7 @@ def TT_LayoutAttr : TT_Attr<"Layout", "layout"> { > tensor<8x96x32xf32, - #tt.layout<(d0, d1, d2) -> (d0 * 96 + d1, d2), + #tt.metal_layout<(d0, d1, d2) -> (d0 * 96 + d1, d2), undef, <2x1>, memref<384x32xf32, #tt.memory_space> @@ -257,7 +257,7 @@ def TT_LayoutAttr : TT_Attr<"Layout", "layout"> { > tensor<8x96x32xf32, - #tt.layout<(d0, d1, d2) -> (d0 * 96 + d1, d1, d2), + #tt.metal_layout<(d0, d1, d2) -> (d0 * 96 + d1, d1, d2), undef, <2x1x2>, memref<384x96x16xf32, #tt.memory_space> @@ -265,7 +265,7 @@ def TT_LayoutAttr : TT_Attr<"Layout", "layout"> { > tensor<5x3x2x2x7x32x32xf32, - #tt.layout< + #tt.metal_layout< (d0, d1, d2, d3, d4, d5, d6) -> (d0 * 2688 + d1 * 896 + d2 * 448 + d3 * 224 + d4 * 32 + d5, d4, d5, d6), undef, @@ -284,7 +284,7 @@ def TT_LayoutAttr : TT_Attr<"Layout", "layout"> { let assemblyFormat = "`<` $linear`,` $oob_val`,` $grid`,` $memref (`,` $mem_layout^)? `>`"; let extraClassDeclaration = [{ - static LayoutAttr get(::mlir::MLIRContext *context, + static MetalLayoutAttr get(::mlir::MLIRContext *context, ArrayRef tensorShape, Type elementType, MemorySpace memorySpace = MemorySpace::System, @@ -292,28 +292,28 @@ def TT_LayoutAttr : TT_Attr<"Layout", "layout"> { ArrayRef> collapseIntervals = {{0, -1}}, OOBVal oobVal = OOBVal::Undef, TensorMemoryLayout memLayout = TensorMemoryLayout::None); - static LayoutAttr get(::mlir::MLIRContext *context, + static MetalLayoutAttr get(::mlir::MLIRContext *context, RankedTensorType ty, MemorySpace memorySpace = MemorySpace::System, GridAttr grid = {}, ArrayRef> collapseIntervals = {{0, -1}}, OOBVal oobVal = OOBVal::Undef, TensorMemoryLayout memLayout = TensorMemoryLayout::None); - static LayoutAttr get(::mlir::MLIRContext *context, + static MetalLayoutAttr get(::mlir::MLIRContext *context, RankedTensorType ty, MemorySpace memorySpace, GridAttr grid, Type elementType, TensorMemoryLayout memLayout = TensorMemoryLayout::None); - LayoutAttr withGrid(::mlir::MLIRContext *context, ArrayRef tensorShape, GridAttr grid, ArrayRef> collapseIntervals = {{0, -1}}); - LayoutAttr withGrid(::mlir::MLIRContext *context, + MetalLayoutAttr withGrid(::mlir::MLIRContext *context, ArrayRef tensorShape, GridAttr grid, ArrayRef> collapseIntervals = {{0, -1}}); + MetalLayoutAttr withGrid(::mlir::MLIRContext *context, RankedTensorType ty, GridAttr grid, ArrayRef> collapseIntervals = {{0, -1}}); - LayoutAttr withElementType(::mlir::MLIRContext *context, Type elementType); - LayoutAttr withMemorySpace(::mlir::MLIRContext *context, MemorySpace memorySpace); - LayoutAttr withMemoryLayout(::mlir::MLIRContext *context, TensorMemoryLayout memLayout); - LayoutAttr withShardShape(::mlir::MLIRContext *context, llvm::SmallVector shardShape); + MetalLayoutAttr withElementType(::mlir::MLIRContext *context, Type elementType); + MetalLayoutAttr withMemorySpace(::mlir::MLIRContext *context, MemorySpace memorySpace); + MetalLayoutAttr withMemoryLayout(::mlir::MLIRContext *context, TensorMemoryLayout memLayout); + MetalLayoutAttr withShardShape(::mlir::MLIRContext *context, llvm::SmallVector shardShape); uint64_t getMemrefSizeBytes() const; MemorySpace getMemorySpace() const; @@ -400,7 +400,7 @@ def TT_DeviceAttr : TT_Attr<"Device", "device", []> { // - DeviceL1: This ends up being exactly the shard size // - DeviceDRAM: Is more nuanced because the whole tensor size gets paged and interleaved between all dram channels, // due to paging and rounding the footprint ends up being close to: the_whole_tensor / num_dram_channels - uint64_t getLayoutSizeBytes(ArrayRef tensorShape, LayoutAttr layout, MemorySpace memorySpace) const; + uint64_t getLayoutSizeBytes(ArrayRef tensorShape, MetalLayoutAttr layout, MemorySpace memorySpace) const; // Returns the footprint size in bytes of the tensor distributed across the given memory space. // Forwards to getLayoutSizeBytes, see comment there for more info. diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index aeb2de1aed..ea84df06b8 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -114,8 +114,8 @@ def TTIR_ToLayoutOp : TTIR_Op<"to_layout", [DestinationStyleOpInterface, TTIROpI - Some combination of the above ```llvm - #layout = #tt.layout<8192x128x1, undef, <1x1>, memref<64x128xf32, #system>> - #layout1 = #tt.layout<8192x128x1, undef, <1x1>, memref<64x128xf32, #l1_>> + #layout = #tt.metal_layout<8192x128x1, undef, <1x1>, memref<64x128xf32, #system>> + #layout1 = #tt.metal_layout<8192x128x1, undef, <1x1>, memref<64x128xf32, #l1_>> %1 = "ttir.to_layout"(%arg0, %0) : (tensor<64x128xf32, #layout>, tensor<64x128xf32, #layout1>) -> tensor<64x128xf32, #layout1> ``` }]; diff --git a/include/ttmlir/Target/Utils/MLIRToFlatbuffer.h b/include/ttmlir/Target/Utils/MLIRToFlatbuffer.h index ac23b9bb0d..d5be2bb97c 100644 --- a/include/ttmlir/Target/Utils/MLIRToFlatbuffer.h +++ b/include/ttmlir/Target/Utils/MLIRToFlatbuffer.h @@ -18,8 +18,9 @@ namespace mlir::tt { flatbuffers::Offset<::tt::target::LayoutDesc> -layoutAttrToFlatbuffer(FlatbufferObjectCache &cache, LayoutAttr attr, - ArrayRef logicalShape, DeviceAttr deviceAttr); +metalLayoutAttrToFlatbuffer(FlatbufferObjectCache &cache, MetalLayoutAttr attr, + ArrayRef logicalShape, + DeviceAttr deviceAttr); flatbuffers::Offset<::tt::target::LayoutDesc> ttnnLayoutAttrToFlatbuffer( FlatbufferObjectCache &cache, ttnn::TTNNLayoutAttr attr, @@ -438,9 +439,9 @@ toFlatbuffer(FlatbufferObjectCache &cache, ElementsAttr elementsAttr) { inline flatbuffers::Offset<::tt::target::LayoutDesc> encodingToFlatbuffer(FlatbufferObjectCache &cache, Attribute attr, ArrayRef logicalShape, DeviceAttr deviceAttr) { - if (isa(attr)) { - return layoutAttrToFlatbuffer(cache, cast(attr), logicalShape, - deviceAttr); + if (isa(attr)) { + return metalLayoutAttrToFlatbuffer(cache, cast(attr), + logicalShape, deviceAttr); } assert(isa(attr) && "unsupported layout attr"); diff --git a/lib/CAPI/TTAttrs.cpp b/lib/CAPI/TTAttrs.cpp index 196dc09f47..c329f41d56 100644 --- a/lib/CAPI/TTAttrs.cpp +++ b/lib/CAPI/TTAttrs.cpp @@ -119,15 +119,15 @@ MlirAttribute ttmlirTTSystemDescAttrGet( chipCapabilitiesUnwrapped, chipCoordsUnwrapped, chipChannelsUnwrapped)); } -MlirAttribute ttmlirTTLayoutAttrGet(MlirContext ctx, MlirAffineMap linear, - unsigned oobVal, MlirAttribute grid, - MlirType memref, unsigned memLayout) { +MlirAttribute ttmlirTTMetalLayoutAttrGet(MlirContext ctx, MlirAffineMap linear, + unsigned oobVal, MlirAttribute grid, + MlirType memref, unsigned memLayout) { mlir::AffineMap affineMap = mlir::AffineMap::getFromOpaquePointer(linear.ptr); - return wrap(LayoutAttr::get(unwrap(ctx), affineMap, - static_cast(oobVal), - mlir::cast(unwrap(grid)), - mlir::cast(unwrap(memref)), - static_cast(memLayout))); + return wrap(MetalLayoutAttr::get(unwrap(ctx), affineMap, + static_cast(oobVal), + mlir::cast(unwrap(grid)), + mlir::cast(unwrap(memref)), + static_cast(memLayout))); } MlirAttribute ttmlirTTMemorySpaceAttrGet(MlirContext ctx, diff --git a/lib/Conversion/TTIRToTTMetal/TTIRToTTMetal.cpp b/lib/Conversion/TTIRToTTMetal/TTIRToTTMetal.cpp index a3bbddc1da..60c0328197 100644 --- a/lib/Conversion/TTIRToTTMetal/TTIRToTTMetal.cpp +++ b/lib/Conversion/TTIRToTTMetal/TTIRToTTMetal.cpp @@ -199,8 +199,8 @@ class TTIRToTTMetalLayoutRewriter : public OpRewritePattern { LogicalResult relayout(ttir::ToLayoutOp op, PatternRewriter &rewriter) const { auto inputTy = mlir::cast(op.getInput().getType()); auto outputTy = mlir::cast(op.getType()); - auto inputLayout = mlir::cast(inputTy.getEncoding()); - auto outputLayout = mlir::cast(outputTy.getEncoding()); + auto inputLayout = mlir::cast(inputTy.getEncoding()); + auto outputLayout = mlir::cast(outputTy.getEncoding()); tt::DeviceAttr device = op.getDevice(); assert(device); tt::SystemDescAttr systemDesc = op.getSystemDesc(); @@ -342,8 +342,8 @@ class TTIRToTTMetalLayoutRewriter : public OpRewritePattern { LogicalResult reformat(ttir::ToLayoutOp op, PatternRewriter &rewriter) const { auto inputTy = mlir::cast(op.getInput().getType()); auto outputTy = mlir::cast(op.getType()); - auto inputLayout = mlir::cast(inputTy.getEncoding()); - auto outputLayout = mlir::cast(outputTy.getEncoding()); + auto inputLayout = mlir::cast(inputTy.getEncoding()); + auto outputLayout = mlir::cast(outputTy.getEncoding()); bool shouldTilize = not inputLayout.isTiled() && outputLayout.isTiled(); bool shouldUntilize = inputLayout.isTiled() && not outputLayout.isTiled(); assert(shouldTilize ^ shouldUntilize); @@ -448,10 +448,10 @@ class TTIRToTTMetalLayoutRewriter : public OpRewritePattern { return failure(); } assert(inputTy.getShape() == outputTy.getShape()); - assert(mlir::isa(inputTy.getEncoding())); - assert(mlir::isa(outputTy.getEncoding())); - auto inputLayout = mlir::cast(inputTy.getEncoding()); - auto outputLayout = mlir::cast(outputTy.getEncoding()); + assert(mlir::isa(inputTy.getEncoding())); + assert(mlir::isa(outputTy.getEncoding())); + auto inputLayout = mlir::cast(inputTy.getEncoding()); + auto outputLayout = mlir::cast(outputTy.getEncoding()); auto components = op.compoundComponents(); bool isCompound = (static_cast(components.isLayoutChange) + @@ -1308,10 +1308,10 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern { SmallVector> calculateDataMovement(ArrayAttr iteratorTypes, const RankedTensorType &src, const RankedTensorType &dst, DeviceAttr device) const { - auto srcLayout = mlir::cast(src.getEncoding()); + auto srcLayout = mlir::cast(src.getEncoding()); assert(srcLayout.isTiled()); - auto dstLayout = mlir::cast(dst.getEncoding()); + auto dstLayout = mlir::cast(dst.getEncoding()); assert(dstLayout.isTiled()); assert(iteratorTypes.size() >= 2 && "Expected at least 2 iterator types"); diff --git a/lib/Dialect/TT/IR/TTDialect.cpp b/lib/Dialect/TT/IR/TTDialect.cpp index 6f629d6977..1ac8a22239 100644 --- a/lib/Dialect/TT/IR/TTDialect.cpp +++ b/lib/Dialect/TT/IR/TTDialect.cpp @@ -13,13 +13,13 @@ using namespace mlir; using namespace mlir::tt; -// This is needed to hoist tt.layout attributes as named attributes declared at -// the module level. +// This is needed to hoist tt.metal_layout attributes as named attributes +// declared at the module level. struct TTOpAsmDialectInterface : public OpAsmDialectInterface { using OpAsmDialectInterface::OpAsmDialectInterface; AliasResult getAlias(Attribute attr, raw_ostream &os) const override { - if (llvm::isa(attr)) { + if (llvm::isa(attr)) { os << "layout"; return AliasResult::OverridableAlias; } diff --git a/lib/Dialect/TT/IR/TTOpsTypes.cpp b/lib/Dialect/TT/IR/TTOpsTypes.cpp index bbdd4e2590..12166e4433 100644 --- a/lib/Dialect/TT/IR/TTOpsTypes.cpp +++ b/lib/Dialect/TT/IR/TTOpsTypes.cpp @@ -466,7 +466,7 @@ calculateLogicalShardShape(mlir::ArrayRef tensorShape, return shardShape; } -LayoutAttr LayoutAttr::get( +MetalLayoutAttr MetalLayoutAttr::get( ::mlir::MLIRContext *context, ArrayRef tensorShape, Type elementType, MemorySpace memorySpace, GridAttr grid, ArrayRef> collapseIntervals, @@ -483,7 +483,7 @@ LayoutAttr LayoutAttr::get( return get(context, linear, oobVal, grid, memref, memLayout); } -LayoutAttr LayoutAttr::get( +MetalLayoutAttr MetalLayoutAttr::get( ::mlir::MLIRContext *context, RankedTensorType ty, MemorySpace memorySpace, GridAttr grid, ArrayRef> collapseIntervals, @@ -493,9 +493,11 @@ LayoutAttr LayoutAttr::get( collapseIntervals, oobVal, memLayout); } -LayoutAttr LayoutAttr::get(::mlir::MLIRContext *context, RankedTensorType ty, - MemorySpace memorySpace, GridAttr grid, - Type elementType, TensorMemoryLayout memLayout) { +MetalLayoutAttr MetalLayoutAttr::get(::mlir::MLIRContext *context, + RankedTensorType ty, + MemorySpace memorySpace, GridAttr grid, + Type elementType, + TensorMemoryLayout memLayout) { assert(ty); assert(grid); return get(context, ty.getShape(), elementType, memorySpace, grid, {{0, -1}}, @@ -506,7 +508,7 @@ LayoutAttr LayoutAttr::get(::mlir::MLIRContext *context, RankedTensorType ty, // compute the physical shape of the tensor, i.e the shape of the tensor // after the dimensions have been collapsed onto a grid. llvm::SmallVector -LayoutAttr::getPhysicalShape(ArrayRef logicalShape) const { +MetalLayoutAttr::getPhysicalShape(ArrayRef logicalShape) const { llvm::SmallVector physicalShape(getGrid().getShape().size()); SmallVector logicalShapeExprs( llvm::map_range(logicalShape, [context = getContext()](std::int64_t e) { @@ -525,7 +527,7 @@ LayoutAttr::getPhysicalShape(ArrayRef logicalShape) const { } llvm::SmallVector -LayoutAttr::getStride(ArrayRef logicalShape) const { +MetalLayoutAttr::getStride(ArrayRef logicalShape) const { llvm::SmallVector stride(logicalShape.size()); @@ -574,7 +576,7 @@ LayoutAttr::getStride(ArrayRef logicalShape) const { } llvm::SmallVector -LayoutAttr::getShardShape(bool convertTileToScalar) const { +MetalLayoutAttr::getShardShape(bool convertTileToScalar) const { SmallVector shardShape(getMemref().getShape()); auto elementType = getElementType(); if (mlir::isa(elementType) && convertTileToScalar) { @@ -583,11 +585,11 @@ LayoutAttr::getShardShape(bool convertTileToScalar) const { return shardShape; } -mlir::Type LayoutAttr::getElementType() const { +mlir::Type MetalLayoutAttr::getElementType() const { return getMemref().getElementType(); } -mlir::Type LayoutAttr::getScalarElementType() const { +mlir::Type MetalLayoutAttr::getScalarElementType() const { auto elementType = getElementType(); if (mlir::isa(elementType)) { return mlir::cast(elementType).getElementType(); @@ -595,33 +597,33 @@ mlir::Type LayoutAttr::getScalarElementType() const { return elementType; } -bool LayoutAttr::hasShardedTensorMemoryLayout() const { +bool MetalLayoutAttr::hasShardedTensorMemoryLayout() const { return (getMemLayout() == TensorMemoryLayout::HeightSharded or getMemLayout() == TensorMemoryLayout::WidthSharded or getMemLayout() == TensorMemoryLayout::BlockSharded); } -bool LayoutAttr::hasInterleavedTensorMemoryLayout() const { +bool MetalLayoutAttr::hasInterleavedTensorMemoryLayout() const { return (getMemLayout() == TensorMemoryLayout::Interleaved); } -bool LayoutAttr::hasShardedL1TensorMemoryLayout() const { +bool MetalLayoutAttr::hasShardedL1TensorMemoryLayout() const { return ::mlir::tt::isL1MemorySpace(getMemorySpace()) and (getMemLayout() == TensorMemoryLayout::HeightSharded or getMemLayout() == TensorMemoryLayout::WidthSharded or getMemLayout() == TensorMemoryLayout::BlockSharded); } -bool LayoutAttr::hasInterleavedL1TensorMemoryLayout() const { +bool MetalLayoutAttr::hasInterleavedL1TensorMemoryLayout() const { return ::mlir::tt::isL1MemorySpace(getMemorySpace()) and (getMemLayout() == TensorMemoryLayout::Interleaved); } -bool LayoutAttr::isTiled() const { +bool MetalLayoutAttr::isTiled() const { return ::mlir::isa<::mlir::tt::TileType>(getElementType()); } -uint64_t LayoutAttr::getElementSizeBytes() const { +uint64_t MetalLayoutAttr::getElementSizeBytes() const { mlir::Type elementType = getElementType(); if (mlir::isa(elementType)) { auto tileType = mlir::cast(elementType); @@ -630,7 +632,7 @@ uint64_t LayoutAttr::getElementSizeBytes() const { return elementType.getIntOrFloatBitWidth() / 8; } -uint64_t LayoutAttr::getMemrefSizeBytes() const { +uint64_t MetalLayoutAttr::getMemrefSizeBytes() const { MemRefType ty = getMemref(); auto shape = ty.getShape(); uint64_t size = getElementSizeBytes(); @@ -638,57 +640,60 @@ uint64_t LayoutAttr::getMemrefSizeBytes() const { std::multiplies()); } -LayoutAttr LayoutAttr::withGrid( +MetalLayoutAttr MetalLayoutAttr::withGrid( ::mlir::MLIRContext *context, ArrayRef tensorShape, GridAttr grid, ArrayRef> collapseIntervals) { return get(context, tensorShape, getElementType(), getMemorySpace(), grid, collapseIntervals, getOobVal(), getMemLayout()); } -LayoutAttr LayoutAttr::withGrid( +MetalLayoutAttr MetalLayoutAttr::withGrid( ::mlir::MLIRContext *context, RankedTensorType ty, GridAttr grid, ArrayRef> collapseIntervals) { assert(ty); - return LayoutAttr::withGrid(context, ty.getShape(), grid, collapseIntervals); + return MetalLayoutAttr::withGrid(context, ty.getShape(), grid, + collapseIntervals); } -LayoutAttr LayoutAttr::withElementType(::mlir::MLIRContext *context, - Type elementType) { - return LayoutAttr::get( +MetalLayoutAttr MetalLayoutAttr::withElementType(::mlir::MLIRContext *context, + Type elementType) { + return MetalLayoutAttr::get( context, getLinear(), getOobVal(), getGrid(), buildMemRef(context, getShardShape(), elementType, getMemorySpace()), getMemLayout()); } -LayoutAttr LayoutAttr::withMemorySpace(::mlir::MLIRContext *context, - MemorySpace memorySpace) { - return LayoutAttr::get( +MetalLayoutAttr MetalLayoutAttr::withMemorySpace(::mlir::MLIRContext *context, + MemorySpace memorySpace) { + return MetalLayoutAttr::get( context, getLinear(), getOobVal(), getGrid(), buildMemRef(context, getShardShape(), getElementType(), memorySpace), getMemLayout()); } -LayoutAttr LayoutAttr::withMemoryLayout(::mlir::MLIRContext *context, - TensorMemoryLayout memLayout) { - return LayoutAttr::get( +MetalLayoutAttr +MetalLayoutAttr::withMemoryLayout(::mlir::MLIRContext *context, + TensorMemoryLayout memLayout) { + return MetalLayoutAttr::get( context, getLinear(), getOobVal(), getGrid(), buildMemRef( context, getShardShape(), getElementType(), getMemorySpace()), memLayout); } -LayoutAttr LayoutAttr::withShardShape(::mlir::MLIRContext *context, - llvm::SmallVector shardShape) { - return LayoutAttr::get( +MetalLayoutAttr +MetalLayoutAttr::withShardShape(::mlir::MLIRContext *context, + llvm::SmallVector shardShape) { + return MetalLayoutAttr::get( context, getLinear(), getOobVal(), getGrid(), buildMemRef( context, shardShape, getElementType(), getMemorySpace()), getMemLayout()); } -MemorySpace LayoutAttr::getMemorySpace() const { +MemorySpace MetalLayoutAttr::getMemorySpace() const { return mlir::cast(getMemref().getMemorySpace()) .getValue(); } @@ -696,7 +701,7 @@ MemorySpace LayoutAttr::getMemorySpace() const { // Returns shape of the tensor after tilization is applied to the two inner most // dimensions. llvm::SmallVector -LayoutAttr::getTiledShape(llvm::ArrayRef tensorShape) const { +MetalLayoutAttr::getTiledShape(llvm::ArrayRef tensorShape) const { assert(isTiled() && "Expected a tiled layout"); mlir::AffineMap linear = getLinear(); @@ -716,7 +721,7 @@ LayoutAttr::getTiledShape(llvm::ArrayRef tensorShape) const { return ttmlir::utils::evalShape(tiled, tensorShape); } -mlir::AffineMap LayoutAttr::getIdentityTileLinearMap() const { +mlir::AffineMap MetalLayoutAttr::getIdentityTileLinearMap() const { assert(isTiled() && "Expected a tiled layout"); return mlir::AffineMap::getMultiDimIdentityMap(getLinear().getNumResults(), @@ -735,7 +740,7 @@ mlir::AffineMap LayoutAttr::getIdentityTileLinearMap() const { // (d0, d1)[2, 3] -> // (0, d0 floordiv 2, d1 floordiv 3, (d0 mod 2) * 3 + d1 mod 3) // -mlir::AffineMap LayoutAttr::replaceMemoryMapSymbolsWithShardShape( +mlir::AffineMap MetalLayoutAttr::replaceMemoryMapSymbolsWithShardShape( AffineMap physicalMemoryMap) const { mlir::SmallVector shardShape = getShardShape(false /*convertTileToScalar*/); @@ -763,8 +768,8 @@ mlir::AffineMap LayoutAttr::replaceMemoryMapSymbolsWithShardShape( // grid. Then it composes the logical grid projection with physical memory // mapping. mlir::AffineMap -LayoutAttr::projectOnto(mlir::AffineMap linearMap, - mlir::AffineMap physicalMemoryMap) const { +MetalLayoutAttr::projectOnto(mlir::AffineMap linearMap, + mlir::AffineMap physicalMemoryMap) const { assert(getGrid().getShape().size() == physicalMemoryMap.getNumDims() && "Layout and device grids must have same number of dimensions"); assert(getLinear().getNumResults() == physicalMemoryMap.getNumDims() && @@ -1013,7 +1018,7 @@ DeviceAttr DeviceAttr::get(::mlir::MLIRContext *context, // Sample the last index in the tensor to get the last addressable element of // the tensor to determine its footprint in memory. uint64_t DeviceAttr::getLayoutSizeBytes(ArrayRef tensorScalarShape, - LayoutAttr layout, + MetalLayoutAttr layout, MemorySpace memorySpace) const { SmallVector shape = layout.isTiled() ? layout.getTiledShape(tensorScalarShape) @@ -1035,9 +1040,9 @@ uint64_t DeviceAttr::getLayoutSizeBytes(ArrayRef tensorScalarShape, uint64_t DeviceAttr::getTensorSizeBytes(RankedTensorType tensorType, MemorySpace memorySpace) const { assert(tensorType.getEncoding()); - return getLayoutSizeBytes(tensorType.getShape(), - mlir::cast(tensorType.getEncoding()), - memorySpace); + return getLayoutSizeBytes( + tensorType.getShape(), + mlir::cast(tensorType.getEncoding()), memorySpace); } ::mlir::LogicalResult diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp index 3cd28626a4..5e5ab6c579 100644 --- a/lib/Dialect/TTIR/IR/TTIROps.cpp +++ b/lib/Dialect/TTIR/IR/TTIROps.cpp @@ -908,9 +908,9 @@ ::mlir::LogicalResult mlir::tt::ttir::ToLayoutOp::verify() { mlir::tt::ttir::ToLayoutOp::CompoundComponents mlir::tt::ttir::ToLayoutOp::compoundComponents() { auto inputLayout = - mlir::cast(getInput().getType().getEncoding()); + mlir::cast(getInput().getType().getEncoding()); auto outputLayout = - mlir::cast(getOutput().getType().getEncoding()); + mlir::cast(getOutput().getType().getEncoding()); bool isLayoutChange = inputLayout.getLinear() != outputLayout.getLinear(); bool isGridChange = inputLayout.getGrid() != outputLayout.getGrid(); bool isShardChange = @@ -1216,7 +1216,7 @@ ::mlir::LogicalResult mlir::tt::ttir::MatmulOp::verify() { // AllocOp verification ::mlir::LogicalResult mlir::tt::ttir::AllocOp::verify() { - auto layout = mlir::dyn_cast_or_null( + auto layout = mlir::dyn_cast_or_null( getResult().getType().getEncoding()); if (not layout) { return emitOpError("Result type missing layout attribute"); diff --git a/lib/Dialect/TTIR/Transforms/Allocate.cpp b/lib/Dialect/TTIR/Transforms/Allocate.cpp index 37e788385c..a643f041c3 100644 --- a/lib/Dialect/TTIR/Transforms/Allocate.cpp +++ b/lib/Dialect/TTIR/Transforms/Allocate.cpp @@ -22,13 +22,13 @@ inline MemorySpace getMemorySpace(MemRefType memref) { return mlir::cast(memref.getMemorySpace()).getValue(); } -inline MemorySpace getMemorySpace(LayoutAttr layout) { +inline MemorySpace getMemorySpace(MetalLayoutAttr layout) { return getMemorySpace(layout.getMemref()); } inline MemorySpace getMemorySpace(RankedTensorType ty) { assert(ty.getEncoding()); - auto layout = mlir::cast(ty.getEncoding()); + auto layout = mlir::cast(ty.getEncoding()); return getMemorySpace(layout); } diff --git a/lib/Dialect/TTIR/Transforms/Generic.cpp b/lib/Dialect/TTIR/Transforms/Generic.cpp index 005e12c079..3bf96f3cd6 100644 --- a/lib/Dialect/TTIR/Transforms/Generic.cpp +++ b/lib/Dialect/TTIR/Transforms/Generic.cpp @@ -257,7 +257,7 @@ class TTIRGenericRegionRewriter auto resEncoding = mlir::cast(op->getResult(0).getType()).getEncoding(); if (resEncoding) { - auto resLayout = mlir::cast(resEncoding); + auto resLayout = mlir::cast(resEncoding); gridAttr = resLayout.getGrid(); } @@ -339,7 +339,7 @@ struct TTIRGenericOperandsToMemrefRewriter auto matchingOperand = generic.getMatchingOperand(blockArgNumber); auto operandType = matchingOperand.getType(); - auto bufferLayout = mlir::cast( + auto bufferLayout = mlir::cast( mlir::cast(operandType).getEncoding()); auto bufferType = operandType; @@ -349,7 +349,7 @@ struct TTIRGenericOperandsToMemrefRewriter assert(static_cast(cbIndex) < generic.getCbs().size()); auto cb = generic.getCbs()[cbIndex]; auto cbType = cb.getType(); - auto cbLayout = mlir::cast( + auto cbLayout = mlir::cast( mlir::cast(cbType).getEncoding()); bufferLayout = cbLayout; bufferType = cbType; @@ -387,7 +387,7 @@ class TTIRGenericRegionMemrefTypeConverter : public TypeConverter { if (mlir::isa(encoding)) { return type; } - auto layout = mlir::cast(type.getEncoding()); + auto layout = mlir::cast(type.getEncoding()); auto buffer = BufferAttr::get(ctx, layout.getMemref(), BufferAccess::Alias); return RankedTensorType::get(buffer.getShape(), type.getElementType(), @@ -451,11 +451,11 @@ class TTIRGenericOpCBsRewriter : public OpRewritePattern { // Enforcing tiled layout as in kernel we always want to work with tiles. auto desiredElementType = rewriter.getType(ty.getElementType()); - auto desiredLayout = rewriter.getAttr( + auto desiredLayout = rewriter.getAttr( ty, MemorySpace::DeviceL1, generic.getGrid(), desiredElementType); auto operandTy = operand.getType(); - auto operandLayout = mlir::cast( + auto operandLayout = mlir::cast( mlir::cast(operandTy).getEncoding()); if (desiredLayout.getGrid() == operandLayout.getGrid()) { diff --git a/lib/Dialect/TTIR/Transforms/Layout.cpp b/lib/Dialect/TTIR/Transforms/Layout.cpp index d7eef6732d..c3ccbf1a44 100644 --- a/lib/Dialect/TTIR/Transforms/Layout.cpp +++ b/lib/Dialect/TTIR/Transforms/Layout.cpp @@ -38,20 +38,21 @@ class TTIRLayoutTensorTypeConverter : public TypeConverter { TTIRLayoutTensorTypeConverter(MLIRContext *ctx, MemorySpace initMemorySpace, GridAttr deviceGrid) { addConversion([](Type type) { return type; }); - addConversion([ctx, initMemorySpace, - deviceGrid](RankedTensorType type) -> Type { - auto layout = type.getEncoding(); - if (layout) { - return type; - } - std::int64_t deviceGridRank = deviceGrid.getShape().size(); - // Default to single core grid - auto tensorGrid = GridAttr::get(ctx, deviceGridRank); - // Default to initMemorySpace, the optimizer might decide otherwise - auto newLayout = LayoutAttr::get(ctx, type, initMemorySpace, tensorGrid); - return RankedTensorType::get(type.getShape(), type.getElementType(), - newLayout); - }); + addConversion( + [ctx, initMemorySpace, deviceGrid](RankedTensorType type) -> Type { + auto layout = type.getEncoding(); + if (layout) { + return type; + } + std::int64_t deviceGridRank = deviceGrid.getShape().size(); + // Default to single core grid + auto tensorGrid = GridAttr::get(ctx, deviceGridRank); + // Default to initMemorySpace, the optimizer might decide otherwise + auto newLayout = + MetalLayoutAttr::get(ctx, type, initMemorySpace, tensorGrid); + return RankedTensorType::get(type.getShape(), type.getElementType(), + newLayout); + }); } }; @@ -129,7 +130,7 @@ createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input, TensorMemoryLayout desiredMemLayout, bool tiled) { auto ty = mlir::cast(input.getType()); - auto currLayout = mlir::cast(ty.getEncoding()); + auto currLayout = mlir::cast(ty.getEncoding()); auto currMemorySpace = currLayout.getMemorySpace(); auto currElementType = currLayout.getElementType(); auto currMemLayout = currLayout.getMemLayout(); @@ -142,9 +143,9 @@ createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input, return std::nullopt; } - auto desiredLayout = - rewriter.getAttr(ty, desiredMemorySpace, currLayout.getGrid(), - desiredElementType, desiredMemLayout); + auto desiredLayout = rewriter.getAttr( + ty, desiredMemorySpace, currLayout.getGrid(), desiredElementType, + desiredMemLayout); tensor::EmptyOp existingEmpty = input.getDefiningOp(); if (existingEmpty) { @@ -343,7 +344,7 @@ class TTIRSplitCompoundLayoutRewriter : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; Value createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input, - LayoutAttr desiredLayout) const { + MetalLayoutAttr desiredLayout) const { auto ty = mlir::cast(input.getType()); auto output = rewriter.create( loc, ty.getShape(), ty.getElementType(), desiredLayout); @@ -353,7 +354,7 @@ class TTIRSplitCompoundLayoutRewriter : public OpRewritePattern { } Value bounce(PatternRewriter &rewriter, ToLayoutOp op, - LayoutAttr bounceLayout) const { + MetalLayoutAttr bounceLayout) const { auto bounced = createToLayoutOp(rewriter, op.getLoc(), op.getInput(), bounceLayout); return rewriter.replaceOpWithNewOp( @@ -375,8 +376,8 @@ class TTIRSplitCompoundLayoutRewriter : public OpRewritePattern { auto inputType = mlir::cast(op.getInput().getType()); auto outputType = mlir::cast(op.getOutput().getType()); - auto inputLayout = mlir::cast(inputType.getEncoding()); - auto outputLayout = mlir::cast(outputType.getEncoding()); + auto inputLayout = mlir::cast(inputType.getEncoding()); + auto outputLayout = mlir::cast(outputType.getEncoding()); bool inputL1 = inputLayout.getMemorySpace() == MemorySpace::DeviceL1; bool outputL1 = outputLayout.getMemorySpace() == MemorySpace::DeviceL1; diff --git a/lib/Dialect/TTMetal/IR/TTMetalOps.cpp b/lib/Dialect/TTMetal/IR/TTMetalOps.cpp index 49baf51e01..7f78c1afcb 100644 --- a/lib/Dialect/TTMetal/IR/TTMetalOps.cpp +++ b/lib/Dialect/TTMetal/IR/TTMetalOps.cpp @@ -17,7 +17,7 @@ namespace mlir::tt::ttmetal { ::mlir::LogicalResult HostWriteOp::verify() { ::mlir::RankedTensorType outputTy = getOutput().getType(); auto outputLayout = - mlir::dyn_cast_or_null(outputTy.getEncoding()); + mlir::dyn_cast_or_null(outputTy.getEncoding()); if (not outputLayout) { return emitOpError("Input tensor missing layout attribute"); } @@ -30,7 +30,7 @@ ::mlir::LogicalResult HostWriteOp::verify() { ::mlir::LogicalResult HostReadOp::verify() { ::mlir::RankedTensorType outputTy = getOutput().getType(); auto outputLayout = - mlir::dyn_cast_or_null(outputTy.getEncoding()); + mlir::dyn_cast_or_null(outputTy.getEncoding()); if (not outputLayout) { return emitOpError("Input tensor missing layout attribute"); } @@ -41,7 +41,7 @@ ::mlir::LogicalResult HostReadOp::verify() { } ::mlir::LogicalResult AllocOp::verify() { - auto layout = mlir::dyn_cast_or_null( + auto layout = mlir::dyn_cast_or_null( getResult().getType().getEncoding()); if (not layout) { return emitOpError("Result type missing layout attribute"); @@ -76,7 +76,7 @@ ::mlir::LogicalResult AllocOp::verify() { ::mlir::LogicalResult DispatchOp::verify() { // Assert inputs/outputs device memspace for (auto operand : getOperands()) { - auto layout = mlir::dyn_cast_or_null( + auto layout = mlir::dyn_cast_or_null( mlir::cast(operand.getType()).getEncoding()); if (not layout) { return emitOpError("Input tensor missing layout attribute"); diff --git a/lib/Target/TTMetal/TTMetalToFlatbuffer.cpp b/lib/Target/TTMetal/TTMetalToFlatbuffer.cpp index 47e15accf6..e82deaf633 100644 --- a/lib/Target/TTMetal/TTMetalToFlatbuffer.cpp +++ b/lib/Target/TTMetal/TTMetalToFlatbuffer.cpp @@ -62,18 +62,18 @@ memrefAttrToFlatbuffer(FlatbufferObjectCache &cache, MemRefType memref, toFlatbuffer(cache, memLayout), size); } -flatbuffers::Offset<::tt::target::LayoutDesc> -layoutAttrToFlatbuffer(FlatbufferObjectCache &cache, LayoutAttr layoutAttr, - ArrayRef logicalShape, DeviceAttr deviceAttr) { - auto strideInt64 = layoutAttr.getStride(logicalShape); +flatbuffers::Offset<::tt::target::LayoutDesc> metalLayoutAttrToFlatbuffer( + FlatbufferObjectCache &cache, MetalLayoutAttr metalLayoutAttr, + ArrayRef logicalShape, DeviceAttr deviceAttr) { + auto strideInt64 = metalLayoutAttr.getStride(logicalShape); std::vector stride(strideInt64.begin(), strideInt64.end()); - auto coreRangeSet = - toFlatbuffer(cache, layoutAttr.getGrid(), deviceAttr.getWorkerGrid()); + auto coreRangeSet = toFlatbuffer(cache, metalLayoutAttr.getGrid(), + deviceAttr.getWorkerGrid()); return ::tt::target::CreateLayoutDescDirect( - *cache.fbb, &stride, toFlatbuffer(cache, layoutAttr.getOobVal()), + *cache.fbb, &stride, toFlatbuffer(cache, metalLayoutAttr.getOobVal()), &coreRangeSet, - cache.getOrCreate(layoutAttr.getMemref(), memrefAttrToFlatbuffer, - layoutAttr.getMemLayout())); + cache.getOrCreate(metalLayoutAttr.getMemref(), memrefAttrToFlatbuffer, + metalLayoutAttr.getMemLayout())); } } // namespace mlir::tt @@ -277,7 +277,7 @@ static std::shared_ptr translateModuleToFlatbuffer( argumentAllocations[input.getArgNumber()]); assert( argAlloc.getMemorySpace() == - mlir::cast( + mlir::cast( mlir::cast(input.getType()).getEncoding()) .getMemorySpace() && "argument allocation memory space does not match tensor type " diff --git a/python/TTModule.cpp b/python/TTModule.cpp index c70d7df974..b8d543410c 100644 --- a/python/TTModule.cpp +++ b/python/TTModule.cpp @@ -16,14 +16,14 @@ namespace mlir::ttmlir::python { void populateTTModule(py::module &m) { - tt_attribute_class(m, "LayoutAttr") + tt_attribute_class(m, "MetalLayoutAttr") .def_static("get", [](MlirContext ctx, MlirType rankedTensorType, uint32_t memorySpaceValue, MlirAttribute grid, std::vector> collapseIntervals, uint32_t oobValValue, uint32_t memLayoutValue) { - return wrap(tt::LayoutAttr::get( + return wrap(tt::MetalLayoutAttr::get( unwrap(ctx), mlir::cast(unwrap(rankedTensorType)), static_cast(memorySpaceValue), @@ -37,7 +37,7 @@ void populateTTModule(py::module &m) { std::vector> collapseIntervals) { return wrap( - mlir::cast(unwrap(self)) + mlir::cast(unwrap(self)) .withGrid(unwrap(ctx), tensorShape, mlir::cast(unwrap(grid)), collapseIntervals)); @@ -47,7 +47,7 @@ void populateTTModule(py::module &m) { std::vector tensorShape, MlirAttribute grid, std::vector> collapseIntervals) { - return mlir::cast(unwrap(self)) + return mlir::cast(unwrap(self)) .withGrid(unwrap(ctx), tensorShape, mlir::cast(unwrap(grid)), collapseIntervals); @@ -55,13 +55,13 @@ void populateTTModule(py::module &m) { .def_static( "with_element_type", [](MlirContext ctx, MlirAttribute self, MlirType elementType) { - return wrap(mlir::cast(unwrap(self)) + return wrap(mlir::cast(unwrap(self)) .withElementType(unwrap(ctx), unwrap(elementType))); }) .def_static( "with_element_type_", [](MlirContext ctx, MlirAttribute self, MlirType elementType) { - return mlir::cast(unwrap(self)) + return mlir::cast(unwrap(self)) .withElementType(unwrap(ctx), unwrap(elementType)); }) .def("getLayout", @@ -73,38 +73,45 @@ void populateTTModule(py::module &m) { mlir::cast(unwrap(type)); assert(tensor.getEncoding()); // Make sure that this Tensor has an // encoding value - tt::LayoutAttr layout = - mlir::cast(tensor.getEncoding()); + tt::MetalLayoutAttr layout = + mlir::cast(tensor.getEncoding()); return layout; }) - .def("wrapped", [](tt::LayoutAttr const &self) { return wrap(self); }) - .def_property_readonly( - "stride", - [](tt::LayoutAttr const &self, std::vector logicalShape) { - auto stride = self.getStride(logicalShape); - return std::vector(stride.begin(), stride.end()); - }) - .def_property_readonly("oobval", &tt::LayoutAttr::getOobVal) + .def("wrapped", + [](tt::MetalLayoutAttr const &self) { return wrap(self); }) + .def_property_readonly("stride", + [](tt::MetalLayoutAttr const &self, + std::vector logicalShape) { + auto stride = self.getStride(logicalShape); + return std::vector(stride.begin(), + stride.end()); + }) + .def_property_readonly("oobval", &tt::MetalLayoutAttr::getOobVal) .def_property_readonly("oobval_as_int", - [](tt::LayoutAttr la) { + [](tt::MetalLayoutAttr la) { return static_cast(la.getOobVal()); }) - .def_property_readonly("grid_attr", &tt::LayoutAttr::getGrid) + .def_property_readonly("grid_attr", &tt::MetalLayoutAttr::getGrid) .def_property_readonly( - "memref", [](tt::LayoutAttr self) { return wrap(self.getMemref()); }) - .def_property_readonly("memory_space", &tt::LayoutAttr::getMemorySpace) + "memref", + [](tt::MetalLayoutAttr self) { return wrap(self.getMemref()); }) + .def_property_readonly("memory_space", + &tt::MetalLayoutAttr::getMemorySpace) .def_property_readonly("memory_space_as_int", - [](tt::LayoutAttr la) { + [](tt::MetalLayoutAttr la) { return static_cast( la.getMemorySpace()); }) - .def_property_readonly("shard_shape", &tt::LayoutAttr::getShardShape) - .def_property_readonly("memory_layout", &tt::LayoutAttr::getMemLayout) + .def_property_readonly("shard_shape", &tt::MetalLayoutAttr::getShardShape) + .def_property_readonly("memory_layout", + &tt::MetalLayoutAttr::getMemLayout) .def_property_readonly( - "linear", [](tt::LayoutAttr self) { return wrap(self.getLinear()); }) - .def_property_readonly("memory_layout_as_int", [](tt::LayoutAttr la) { - return static_cast(la.getMemLayout()); - }); + "linear", + [](tt::MetalLayoutAttr self) { return wrap(self.getLinear()); }) + .def_property_readonly("memory_layout_as_int", + [](tt::MetalLayoutAttr la) { + return static_cast(la.getMemLayout()); + }); tt_attribute_class(m, "GridAttr") .def_static("get", diff --git a/test/python/tensor_layout.py b/test/python/tensor_layout.py index 39a9a728be..2dbf249e9f 100644 --- a/test/python/tensor_layout.py +++ b/test/python/tensor_layout.py @@ -34,7 +34,7 @@ def createTensorLayout( shape, F32Type.get(ctx), None, Location.unknown(ctx) ) memoryLayout = getTensorMemoryLayout(memorySpace) - layout = tt.ir.LayoutAttr.get( + layout = tt.ir.MetalLayoutAttr.get( ctx, tensorTy, memorySpace, grid, collapseIntervals, oobVal, memoryLayout ) return RankedTensorType.get(shape, F32Type.get(ctx), layout, Location.unknown(ctx)) @@ -42,7 +42,7 @@ def createTensorLayout( def tilize(tensor, dataType, tileShape=[32, 32]): assert len(tileShape) == 2 - return tt.ir.LayoutAttr.with_element_type_( + return tt.ir.MetalLayoutAttr.with_element_type_( ctx, tensor.encoding, tt.ir.TileType.get(ctx, tileShape[0], tileShape[1], dataType), @@ -52,15 +52,15 @@ def tilize(tensor, dataType, tileShape=[32, 32]): def parallelize(tensor, grid, collapseIntervals=[(0, -1)]): if isinstance(grid, list) or isinstance(grid, tuple): grid = tt.ir.GridAttr.get(ctx, list(grid)) - return tt.ir.LayoutAttr.with_grid_( + return tt.ir.MetalLayoutAttr.with_grid_( ctx, tensor.encoding, tensor.shape, grid, collapseIntervals ) t0 = createTensorLayout([2, 3, 64, 128], [2, 4]) -# CHECK: tensor<2x3x64x128xf32, #tt.layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <2x4>, memref<192x32xf32, #tt.memory_space>, interleaved>> +# CHECK: tensor<2x3x64x128xf32, #tt.metal_layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <2x4>, memref<192x32xf32, #tt.memory_space>, interleaved>> print(t0) -# CHECK: #tt.layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <2x4>, memref<6x1x!tt.tile<32x32, bfp_bf8>, #tt.memory_space>, interleaved> +# CHECK: #tt.metal_layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <2x4>, memref<6x1x!tt.tile<32x32, bfp_bf8>, #tt.memory_space>, interleaved> print(tilize(t0, tt.DataType.BFP_BFloat8).wrapped()) print(parallelize(t0, [3, 2]).wrapped()) @@ -69,24 +69,24 @@ def parallelize(tensor, grid, collapseIntervals=[(0, -1)]): print(parallelize(t1, [3, 2]).wrapped()) t2 = createTensorLayout([128], [4], collapseIntervals=[(0, -1)]) -# CHECK: tensor<128xf32, #tt.layout<(d0) -> (d0), undef, <4>, memref<32xf32, #tt.memory_space>, interleaved>> +# CHECK: tensor<128xf32, #tt.metal_layout<(d0) -> (d0), undef, <4>, memref<32xf32, #tt.memory_space>, interleaved>> print(t2) -# CHECK: #tt.layout<(d0) -> (d0), undef, <2>, memref<64xf32, #tt.memory_space>, interleaved> +# CHECK: #tt.metal_layout<(d0) -> (d0), undef, <2>, memref<64xf32, #tt.memory_space>, interleaved> print(parallelize(t2, [2]).wrapped()) -# CHECK: #tt.layout<(d0) -> (0, d0), undef, <1x2>, memref<1x64xf32, #tt.memory_space>, interleaved> +# CHECK: #tt.metal_layout<(d0) -> (0, d0), undef, <1x2>, memref<1x64xf32, #tt.memory_space>, interleaved> print(parallelize(t2, [1, 2]).wrapped()) t3 = createTensorLayout([128], [1, 4], collapseIntervals=[(0, -1)]) -# CHECK: tensor<128xf32, #tt.layout<(d0) -> (0, d0), undef, <1x4>, memref<1x32xf32, #tt.memory_space>, interleaved>> +# CHECK: tensor<128xf32, #tt.metal_layout<(d0) -> (0, d0), undef, <1x4>, memref<1x32xf32, #tt.memory_space>, interleaved>> print(t3) -# CHECK: #tt.layout<(d0) -> (0, d0), undef, <1x4>, memref<1x1x!tt.tile<32x32, bfp_bf8>, #tt.memory_space>, interleaved> +# CHECK: #tt.metal_layout<(d0) -> (0, d0), undef, <1x4>, memref<1x1x!tt.tile<32x32, bfp_bf8>, #tt.memory_space>, interleaved> print(tilize(t3, tt.DataType.BFP_BFloat8).wrapped()) t4 = createTensorLayout([128], [1, 2, 4], collapseIntervals=[(0, -1)]) -# CHECK: tensor<128xf32, #tt.layout<(d0) -> (0, 0, d0), undef, <1x2x4>, memref<1x1x32xf32, #tt.memory_space>, interleaved>> +# CHECK: tensor<128xf32, #tt.metal_layout<(d0) -> (0, 0, d0), undef, <1x2x4>, memref<1x1x32xf32, #tt.memory_space>, interleaved>> print(t4) -# CHECK: #tt.layout<(d0) -> (0, 0, d0), undef, <1x2x4>, memref<1x1x1x!tt.tile<32x32, bfp_bf8>, #tt.memory_space>, interleaved> +# CHECK: #tt.metal_layout<(d0) -> (0, 0, d0), undef, <1x2x4>, memref<1x1x1x!tt.tile<32x32, bfp_bf8>, #tt.memory_space>, interleaved> print(tilize(t4, tt.DataType.BFP_BFloat8).wrapped()) -# CHECK: #tt.layout<(d0) -> (0, d0), undef, <1x2>, memref<1x64xf32, #tt.memory_space>, interleaved> +# CHECK: #tt.metal_layout<(d0) -> (0, d0), undef, <1x2>, memref<1x64xf32, #tt.memory_space>, interleaved> print(parallelize(t4, [1, 2]).wrapped()) diff --git a/test/ttmlir/Dialect/TTIR/split_compound_layout.mlir b/test/ttmlir/Dialect/TTIR/split_compound_layout.mlir index 2335fb0df3..42cab3d1f6 100644 --- a/test/ttmlir/Dialect/TTIR/split_compound_layout.mlir +++ b/test/ttmlir/Dialect/TTIR/split_compound_layout.mlir @@ -3,21 +3,21 @@ #dram = #tt.memory_space #l1_ = #tt.memory_space -// CHECK-DAG: #[[row_major1x1:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>, interleaved> -// CHECK-DAG: #[[row_major1x1_T:.*]] = #tt.layout<(d0, d1) -> (d1, d0), undef, <1x1>, memref<64x128xf32, #l1_>, interleaved> -// CHECK-DAG: #[[row_major2x2:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<32x64xf32, #l1_>, interleaved> -// CHECK-DAG: #[[tile1x1_f32:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, f32>, #l1_>, interleaved> -// CHECK-DAG: #[[tile1x1_bf16:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #l1_>, interleaved> -// CHECK-DAG: #[[tile1x1_f32_dram:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, f32>, #dram>, interleaved> -// CHECK-DAG: #[[tile2x2_f32:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<1x2x!tt.tile<32x32, f32>, #l1_>, interleaved> +// CHECK-DAG: #[[row_major1x1:.*]] = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>, interleaved> +// CHECK-DAG: #[[row_major1x1_T:.*]] = #tt.metal_layout<(d0, d1) -> (d1, d0), undef, <1x1>, memref<64x128xf32, #l1_>, interleaved> +// CHECK-DAG: #[[row_major2x2:.*]] = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<32x64xf32, #l1_>, interleaved> +// CHECK-DAG: #[[tile1x1_f32:.*]] = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, f32>, #l1_>, interleaved> +// CHECK-DAG: #[[tile1x1_bf16:.*]] = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #l1_>, interleaved> +// CHECK-DAG: #[[tile1x1_f32_dram:.*]] = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, f32>, #dram>, interleaved> +// CHECK-DAG: #[[tile2x2_f32:.*]] = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<1x2x!tt.tile<32x32, f32>, #l1_>, interleaved> -#row_major1x1 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>, interleaved> -#row_major1x1_T = #tt.layout<(d0, d1) -> (d1, d0), undef, <1x1>, memref<64x128xf32, #l1_>, interleaved> -#row_major2x2 = #tt.layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<32x64xf32, #l1_>, interleaved> -#tile1x1_f32 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, f32>, #l1_>, interleaved> -#tile1x1_bf16 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #l1_>, interleaved> -#tile1x1_f32_dram = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, f32>, #dram>, interleaved> -#tile2x2_f32 = #tt.layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<1x2x!tt.tile<32x32, f32>, #l1_>, interleaved> +#row_major1x1 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>, interleaved> +#row_major1x1_T = #tt.metal_layout<(d0, d1) -> (d1, d0), undef, <1x1>, memref<64x128xf32, #l1_>, interleaved> +#row_major2x2 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<32x64xf32, #l1_>, interleaved> +#tile1x1_f32 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, f32>, #l1_>, interleaved> +#tile1x1_bf16 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #l1_>, interleaved> +#tile1x1_f32_dram = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, f32>, #dram>, interleaved> +#tile2x2_f32 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<1x2x!tt.tile<32x32, f32>, #l1_>, interleaved> func.func @noncompound_linear(%in: tensor<64x128xf32, #row_major1x1>) -> tensor<64x128xf32, #row_major1x1_T> { %out = tensor.empty() : tensor<64x128xf32, #row_major1x1_T> diff --git a/test/ttmlir/Dialect/TTIR/test_allocate.mlir b/test/ttmlir/Dialect/TTIR/test_allocate.mlir index a80a8c1c91..5888cf3f62 100644 --- a/test/ttmlir/Dialect/TTIR/test_allocate.mlir +++ b/test/ttmlir/Dialect/TTIR/test_allocate.mlir @@ -1,7 +1,7 @@ // RUN: ttmlir-opt --ttir-load-system-desc --ttir-implicit-device --ttir-allocate %s | FileCheck %s #any_device = #tt.operand_constraint #l1_ = #tt.memory_space -#layout = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>, interleaved> +#layout = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>, interleaved> module attributes {} { func.func @forward(%arg0: tensor<64x128xf32, #layout>, %arg1: tensor<64x128xf32, #layout>) -> tensor<64x128xf32, #layout> { // CHECK: %[[C:.*]] = "ttir.alloc"[[C:.*]] diff --git a/test/ttmlir/Silicon/TTMetal/simple_reduce.mlir b/test/ttmlir/Silicon/TTMetal/simple_reduce.mlir index 1674ae0d32..cdde621c2a 100644 --- a/test/ttmlir/Silicon/TTMetal/simple_reduce.mlir +++ b/test/ttmlir/Silicon/TTMetal/simple_reduce.mlir @@ -1,8 +1,8 @@ // RUN: ttmlir-opt --ttir-to-ttmetal-backend-pipeline="system-desc-path=%system_desc_path%" %s | FileCheck %s #any_device = #tt.operand_constraint #l1_ = #tt.memory_space -#layout1 = #tt.layout<(d0, d1) -> (d0, d1), undef, <4x4>, memref<64x96xf32, #l1_>> -#layout2 = #tt.layout<(d0, d1) -> (d0, d1), undef, <4x1>, memref<64x32xf32, #l1_>> +#layout1 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <4x4>, memref<64x96xf32, #l1_>> +#layout2 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <4x1>, memref<64x32xf32, #l1_>> func.func @reduceW(%arg0: tensor<256x384xf32, #layout1>) -> tensor<256x32xf32, #layout2> { %0 = tensor.empty() : tensor<256x32xf32, #layout2> @@ -15,7 +15,7 @@ func.func @reduceW(%arg0: tensor<256x384xf32, #layout1>) -> tensor<256x32xf32, # return %1 : tensor<256x32xf32, #layout2> } -#layout3 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x4>, memref<32x96xf32, #l1_>> +#layout3 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x4>, memref<32x96xf32, #l1_>> func.func @reduceH(%arg0: tensor<256x384xf32, #layout1>) -> tensor<32x384xf32, #layout3> { %0 = tensor.empty() : tensor<32x384xf32, #layout3> // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] @@ -27,7 +27,7 @@ func.func @reduceH(%arg0: tensor<256x384xf32, #layout1>) -> tensor<32x384xf32, # return %1 : tensor<32x384xf32, #layout3> } -#layout4 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<32x32xf32, #l1_>> +#layout4 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<32x32xf32, #l1_>> func.func @reduceWH(%arg0: tensor<256x384xf32, #layout1>) -> tensor<32x32xf32, #layout4> { %0 = tensor.empty() : tensor<32x32xf32, #layout4> // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] diff --git a/test/ttmlir/Silicon/TTMetal/tiled_reblock.mlir b/test/ttmlir/Silicon/TTMetal/tiled_reblock.mlir index 64cf5f57a6..d7d3cea1dd 100644 --- a/test/ttmlir/Silicon/TTMetal/tiled_reblock.mlir +++ b/test/ttmlir/Silicon/TTMetal/tiled_reblock.mlir @@ -4,10 +4,10 @@ #l1_ = #tt.memory_space -#untilized = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>> -#tilized = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32 x 32, f32>, #l1_>> -#tilized2x2 = #tt.layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<1x2x!tt.tile<32 x 32, f32>, #l1_>> -#untilized2x2 = #tt.layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<32x64xf32, #l1_>> +#untilized = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>> +#tilized = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32 x 32, f32>, #l1_>> +#tilized2x2 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<1x2x!tt.tile<32 x 32, f32>, #l1_>> +#untilized2x2 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<32x64xf32, #l1_>> func.func @tilize_reblock_2D(%arg0: tensor<64x128xf32, #untilized>) -> tensor<64x128xf32, #untilized2x2> { // CHECK: %[[C:.*]] = "ttmetal.alloc"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32, #tilized> @@ -25,10 +25,10 @@ func.func @tilize_reblock_2D(%arg0: tensor<64x128xf32, #untilized>) -> tensor<64 } -#untilized4D = #tt.layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <1x1>, memref<384x128xf32, #l1_>> -#tilized4D = #tt.layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <1x1>, memref<12x4x!tt.tile<32 x 32, f32>, #l1_>> -#tilized4D_2x2 = #tt.layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <2x2>, memref<6x2x!tt.tile<32 x 32, f32>, #l1_>> -#untilized4D_2x2 = #tt.layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <2x2>, memref<192x64xf32, #l1_>> +#untilized4D = #tt.metal_layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <1x1>, memref<384x128xf32, #l1_>> +#tilized4D = #tt.metal_layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <1x1>, memref<12x4x!tt.tile<32 x 32, f32>, #l1_>> +#tilized4D_2x2 = #tt.metal_layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <2x2>, memref<6x2x!tt.tile<32 x 32, f32>, #l1_>> +#untilized4D_2x2 = #tt.metal_layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <2x2>, memref<192x64xf32, #l1_>> func.func @tilize_reblock_4D(%arg0: tensor<2x3x64x128xf32, #untilized4D>) -> tensor<2x3x64x128xf32, #untilized4D_2x2> { // CHECK: %[[C:.*]] = "ttmetal.alloc"[[C:.*]] %0 = tensor.empty() : tensor<2x3x64x128xf32, #tilized4D> @@ -48,10 +48,10 @@ func.func @tilize_reblock_4D(%arg0: tensor<2x3x64x128xf32, #untilized4D>) -> ten return %5 : tensor<2x3x64x128xf32, #untilized4D_2x2> } -#untilized_big = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<96x192xf32, #l1_>> -#tilized_big = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<3x6x!tt.tile<32 x 32, f32>, #l1_>> -#tilized_big_3x2 = #tt.layout<(d0, d1) -> (d0, d1), undef, <3x2>, memref<1x3x!tt.tile<32 x 32, f32>, #l1_>> -#tilized_big_3x6 = #tt.layout<(d0, d1) -> (d0, d1), undef, <3x6>, memref<1x1x!tt.tile<32 x 32, f32>, #l1_>> +#untilized_big = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<96x192xf32, #l1_>> +#tilized_big = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<3x6x!tt.tile<32 x 32, f32>, #l1_>> +#tilized_big_3x2 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <3x2>, memref<1x3x!tt.tile<32 x 32, f32>, #l1_>> +#tilized_big_3x6 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <3x6>, memref<1x1x!tt.tile<32 x 32, f32>, #l1_>> func.func @tilize_reblock_big(%arg0: tensor<96x192xf32, #untilized_big>) -> tensor<96x192xf32, #untilized_big> { // move to tilized 1x1 // CHECK: %[[C:.*]] = "ttmetal.alloc"[[C:.*]] diff --git a/test/ttmlir/Silicon/TTMetal/to_layout.mlir b/test/ttmlir/Silicon/TTMetal/to_layout.mlir index 015e651750..e5318c6c1d 100644 --- a/test/ttmlir/Silicon/TTMetal/to_layout.mlir +++ b/test/ttmlir/Silicon/TTMetal/to_layout.mlir @@ -5,8 +5,8 @@ #l1_ = #tt.memory_space #dram = #tt.memory_space -#layout = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<4x16xf32, #l1_>> -#layout1 = #tt.layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<2x8xf32, #l1_>> +#layout = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<4x16xf32, #l1_>> +#layout1 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<2x8xf32, #l1_>> func.func @simple(%arg0: tensor<4x16xf32, #layout>) -> tensor<4x16xf32, #layout1> { %0 = tensor.empty() : tensor<4x16xf32, #layout1> // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] @@ -14,8 +14,8 @@ func.func @simple(%arg0: tensor<4x16xf32, #layout>) -> tensor<4x16xf32, #layout1 return %1 : tensor<4x16xf32, #layout1> } -#untilized = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>> -#tilized = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32 x 32, f32>, #l1_>> +#untilized = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>> +#tilized = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32 x 32, f32>, #l1_>> func.func @tilize(%arg0: tensor<64x128xf32, #untilized>) -> tensor<64x128xf32, #untilized> { %0 = tensor.empty() : tensor<64x128xf32, #tilized> // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] @@ -26,11 +26,11 @@ func.func @tilize(%arg0: tensor<64x128xf32, #untilized>) -> tensor<64x128xf32, # return %3 : tensor<64x128xf32, #untilized> } -#untilized_dram = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<16x64xf32, #dram>> -#untilized_l1 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<16x64xf32, #l1_>> -#untilized2x2_dram = #tt.layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<8x32xf32, #dram>> -#untilized2x2_l1 = #tt.layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<8x32xf32, #l1_>> -#untilized1x4_l1 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x4>, memref<16x16xf32, #l1_>> +#untilized_dram = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<16x64xf32, #dram>> +#untilized_l1 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<16x64xf32, #l1_>> +#untilized2x2_dram = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<8x32xf32, #dram>> +#untilized2x2_l1 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<8x32xf32, #l1_>> +#untilized1x4_l1 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x4>, memref<16x16xf32, #l1_>> func.func @dram_to_l1(%arg0: tensor<16x64xf32, #untilized_dram>) -> tensor<16x64xf32, #untilized_l1> { %0 = tensor.empty() : tensor<16x64xf32, #untilized_l1> // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/mnist.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/mnist.mlir index ba995925d5..0193ec36b1 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/mnist.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/mnist.mlir @@ -4,8 +4,8 @@ #loc = loc("MNISTLinear":4294967295:0) module @"tt-forge-graph" attributes {} { func.func @main(%arg0: tensor<1x784xf32> loc("MNISTLinear":4294967295:0), %arg1: tensor<1x10xf32> loc("MNISTLinear":4294967295:0), %arg2: tensor<256x10xf32> loc("MNISTLinear":4294967295:0), %arg3: tensor<1x256xf32> loc("MNISTLinear":4294967295:0), %arg4: tensor<784x256xf32> loc("MNISTLinear":4294967295:0)) -> tensor<1x10xf32> { - // CHECK: #[[LAYOUT_10:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x8>, memref<1x32xf32, #l1_>, block_sharded> - // CHECK: #[[LAYOUT_11:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<1x10xf32, #l1_>, block_sharded> + // CHECK: #[[LAYOUT_10:.*]] = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x8>, memref<1x32xf32, #l1_>, block_sharded> + // CHECK: #[[LAYOUT_11:.*]] = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<1x10xf32, #l1_>, block_sharded> %0 = tensor.empty() : tensor<1x256xf32> loc(#loc8) // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] -> tensor<1x256xf32, #[[LAYOUT_10]]> %1 = "ttir.matmul"(%arg0, %arg4, %0) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x784xf32>, tensor<784x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc8) diff --git a/tools/explorer/tt_adapter/src/tt_adapter/mlir.py b/tools/explorer/tt_adapter/src/tt_adapter/mlir.py index 5233e844c2..b9ae471ca5 100644 --- a/tools/explorer/tt_adapter/src/tt_adapter/mlir.py +++ b/tools/explorer/tt_adapter/src/tt_adapter/mlir.py @@ -348,7 +348,7 @@ def parse_dimension(attr): @AttrHandler.register_handler("tt.layout") def parse_tt_layout(attr): - layout = tt.ir.LayoutAttr.maybe_downcast(attr) + layout = tt.ir.MetalLayoutAttr.maybe_downcast(attr) result = [] result.append(graph_builder.KeyValue(key="linear", value=str(layout.linear))) result.append(