Commit

Add simple test to be used for emitc path (#1217)
* Add simple test to be used for emitc path

* fix failed filecheck

* fix broken include

* fix force attr for deallocate

* straggler files added

* fix dealloc test

* use deallocate from ttnn namespace; read force param
svuckovicTT authored Nov 13, 2024
1 parent 430b036 commit cb1e6fc
Showing 14 changed files with 151 additions and 91 deletions.
9 changes: 5 additions & 4 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -701,13 +701,14 @@ def TTNN_AllocOp : TTNN_Op<"alloc"> {
   let hasVerifier = 1;
 }
 
-def TTNN_DeallocOp : TTNN_Op<"dealloc"> {
-  let summary = "Dealloc op.";
+def TTNN_DeallocateOp : TTNN_Op<"deallocate"> {
+  let summary = "Deallocate op.";
   let description = [{
-    Tensor Dealloc operation
+    Tensor Deallocate operation
   }];
 
-  let arguments = (ins AnyRankedTensor:$input);
+  let arguments = (ins AnyRankedTensor:$input,
+                       DefaultValuedAttr<BoolAttr, "false">:$force);
 }
 
 def TTNN_AllGatherOp: TTNN_Op<"all_gather"> {
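Since force is modelled as a DefaultValuedAttr<BoolAttr, "false">, TableGen generates attribute accessors on the C++ op class, which the flatbuffer serialization later in this commit relies on. A minimal sketch of how they are used (assumed from standard MLIR ODS codegen; the helper function is hypothetical):

// Hypothetical helper showing the assumed ODS-generated accessors:
void inspectDeallocate(mlir::tt::ttnn::DeallocateOp op) {
  mlir::BoolAttr forceAttr = op.getForceAttr(); // attribute form, used in
                                                // TTNNToFlatbuffer.cpp below
  bool force = forceAttr.getValue();            // unwrapped value
  (void)force;
}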
5 changes: 3 additions & 2 deletions include/ttmlir/Target/TTNN/program.fbs
@@ -213,8 +213,9 @@ table MaxPool2dOp {
   padding_width: uint32;
 }
 
-table DeallocOp {
+table DeallocateOp {
   in: tt.target.TensorRef;
+  force: bool;
 }
 
 table AllGatherOp {
@@ -244,7 +245,7 @@ union OpType {
   ReshapeOp,
   SliceOp,
   MaxPool2dOp,
-  DeallocOp,
+  DeallocateOp,
   AllGatherOp,
 }
88 changes: 59 additions & 29 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -62,6 +62,10 @@ emitc::OpaqueAttr convertLayoutAttr(Builder &builder, ttnn::LayoutAttr attr) {
   llvm_unreachable("Unknown ttnn::Layout");
 }
 
+emitc::OpaqueAttr convertBoolAttr(Builder &builder, BoolAttr attr) {
+  return builder.getType<emitc::OpaqueAttr>(attr.getValue() ? "true" : "false");
+}
+
 // Create emitc::OpaqueAttr for ttnn::TensorMemoryLayout
 //
 emitc::OpaqueAttr convertTensorMemoryLayout(Builder &builder,
@@ -220,26 +224,25 @@ class DefaultOpConversionPattern
   }
 };
 
-// MultiplyOp conversion pattern
+// Eltwise Binary op conversion pattern
 //
 // TODO(bug #623):
-// Convert all DPS-supported ttnn ops to this conversion pattern (nullopts added
-// for correct signature)
+// Currently, it has to insert nullopts for some parameters that are not
+// modelled in the dialect (output dtype, memcfg)
 //
-class MultiplyOpConversionPattern
-    : public TTNNToEmitCBaseOpConversionPattern<ttnn::MultiplyOp> {
+template <typename SourceOp, typename Adaptor = typename SourceOp::Adaptor>
+class EltwiseBinaryOpConversionPattern
+    : public TTNNToEmitCBaseOpConversionPattern<SourceOp> {
 
 public:
-  MultiplyOpConversionPattern(const TypeConverter &typeConverter,
-                              MLIRContext *context, PatternBenefit benefit = 1)
-      : TTNNToEmitCBaseOpConversionPattern<ttnn::MultiplyOp>(typeConverter,
-                                                             context, benefit) {
-  }
+  EltwiseBinaryOpConversionPattern(const TypeConverter &typeConverter,
+                                   MLIRContext *context,
+                                   PatternBenefit benefit = 1)
+      : TTNNToEmitCBaseOpConversionPattern<SourceOp>(typeConverter, context,
+                                                     benefit) {}
 
   LogicalResult
-  matchAndRewrite(ttnn::MultiplyOp srcOp, OpAdaptor adaptor,
+  matchAndRewrite(SourceOp srcOp, Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
     // emitc::CallOpaqueOp needs to know positions of operands vs attributes, so
     // an ArrayAttr object holding IndexTypes is created to denote this
     //
@@ -502,7 +505,7 @@ class EmptyOpConversionPattern
   // specific datatypes on input. However, one of the constructors takes in a
   // tt_metal::Shape - given that it's much easier to construct a
   // tt_metal::Shape, we opted to do that here. The call looks like this:
-  // ttnn::Shape(tt::tt_metal::Shape{dim0, dim1, dim2, ...});
+  // ttnn::Shape(tt::tt_metal::LegacyShape{dim0, dim1, dim2, ...});
   //
   // To make it easier on the eyes, these two calls are packed into one, using
   // EmitC's ExpressionOp.
@@ -517,8 +520,9 @@
     rewriter.setInsertionPointToStart(&bodyBlock);
     emitc::CallOpaqueOp metalShapeOp = rewriter.create<emitc::CallOpaqueOp>(
         srcOp->getLoc(),
-        emitc::OpaqueType::get(rewriter.getContext(), "tt::metal::Shape"),
-        rewriter.getStringAttr("Shape"),
+        emitc::OpaqueType::get(rewriter.getContext(),
+                               "tt::tt_metal::LegacyShape"),
+        rewriter.getStringAttr("tt::tt_metal::LegacyShape"),
         rewriter.getArrayAttr(convertShape(rewriter, shapeAttr)), nullptr,
         ValueRange());
     emitc::CallOpaqueOp ttnnShapeOp = rewriter.create<emitc::CallOpaqueOp>(
@@ -537,10 +541,7 @@
     //
    ArrayAttr arrayAttr;
    if (adaptor.getDevice()) {
-      mlir::emitc::ApplyOp derefDevice = rewriter.create<emitc::ApplyOp>(
-          srcOp->getLoc(), rewriter.getType<emitc::OpaqueType>("ttnn::Device&"),
-          "*", adaptor.getDevice());
-      operands.append(1, derefDevice->getResult(0));
+      operands.append(1, adaptor.getDevice());
 
       // Create ArrayAttr object holding MemoryConfig attributes
       //
@@ -588,6 +589,35 @@
   }
 };
 
+// DeallocateOp conversion pattern
+//
+class DeallocateOpConversionPattern
+    : public TTNNToEmitCBaseOpConversionPattern<ttnn::DeallocateOp> {
+
+public:
+  DeallocateOpConversionPattern(const TypeConverter &typeConverter,
+                                MLIRContext *context,
+                                PatternBenefit benefit = 1)
+      : TTNNToEmitCBaseOpConversionPattern<ttnn::DeallocateOp>(
+            typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(ttnn::DeallocateOp srcOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    ArrayAttr arrayAttr = rewriter.getArrayAttr({
+        rewriter.getIndexAttr(0),
+        convertBoolAttr(rewriter, srcOp.getForceAttr()),
+    });
+
+    rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
+        srcOp, srcOp->getResultTypes(), this->convertOpName(srcOp), arrayAttr,
+        nullptr, adaptor.getOperands());
+
+    return success();
+  }
+};
+
 } // namespace
 
 namespace mlir::tt {
@@ -601,10 +631,10 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
 
   // Memory ops
   //
-  patterns.add<ToLayoutOpConversionPattern, TypecastOpConversionPattern,
-               ToDeviceOpConversionPattern, FromDeviceOpConversionPattern,
-               DefaultOpConversionPattern<ttnn::DeallocOp>,
-               ToMemoryConfigOpConversionPattern>(typeConverter, ctx);
+  patterns.add<ToLayoutOpConversionPattern, ToMemoryConfigOpConversionPattern,
+               TypecastOpConversionPattern, ToDeviceOpConversionPattern,
+               FromDeviceOpConversionPattern, DeallocateOpConversionPattern>(
+      typeConverter, ctx);
 
   // Tensor ops
   //
@@ -636,11 +666,11 @@
 
   // Eltwise binary ops
   //
-  patterns.add<DefaultOpConversionPattern<ttnn::AddOp>,
-               DefaultOpConversionPattern<ttnn::LogicalAndOp>,
-               DefaultOpConversionPattern<ttnn::LogicalOrOp>,
-               DefaultOpConversionPattern<ttnn::SubtractOp>,
-               MultiplyOpConversionPattern,
+  patterns.add<EltwiseBinaryOpConversionPattern<ttnn::AddOp>,
+               EltwiseBinaryOpConversionPattern<ttnn::LogicalAndOp>,
+               EltwiseBinaryOpConversionPattern<ttnn::LogicalOrOp>,
+               EltwiseBinaryOpConversionPattern<ttnn::SubtractOp>,
+               EltwiseBinaryOpConversionPattern<ttnn::MultiplyOp>,
                DefaultOpConversionPattern<ttnn::EqualOp>,
                DefaultOpConversionPattern<ttnn::NotEqualOp>,
                DefaultOpConversionPattern<ttnn::GreaterEqualOp>,
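For orientation: once translated, the emitc.call_opaque ops produced by these patterns print as plain C++ calls into the TTNN API. Under the patterns above, an add followed by deallocates would come out roughly like the sketch below (illustrative only: the variable names are placeholders, the two std::nullopt arguments are the unmodelled output-dtype and memory-config parameters mentioned in the TODO, and false is the force attribute's default):

// Hypothetical generated C++ (sketch, not actual output of this commit):
ttnn::Tensor v3 = ttnn::add(v1, v2, std::nullopt, std::nullopt);
ttnn::deallocate(v1, false); // force = false, the attribute's default
ttnn::deallocate(v2, false);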
2 changes: 1 addition & 1 deletion lib/Dialect/TTNN/Transforms/Passes.cpp
@@ -91,7 +91,7 @@ class TTNNDeallocate : public impl::TTNNDeallocateBase<TTNNDeallocate> {
         }
 
         rewriter.setInsertionPointAfter(lastOp);
-        rewriter.create<DeallocOp>(lastOp->getLoc(), result);
+        rewriter.create<DeallocateOp>(lastOp->getLoc(), result);
       }
     });
   });
13 changes: 7 additions & 6 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -525,12 +525,13 @@ createSoftmaxOp(FlatbufferObjectCache &cache, SoftmaxOp op) {
   return ::tt::target::ttnn::CreateSoftmaxOp(*cache.fbb, in, out, dimension);
 }
 
-template <typename DeallocOp>
-::flatbuffers::Offset<::tt::target::ttnn::DeallocOp>
-createDeallocOp(FlatbufferObjectCache &cache, DeallocOp op) {
+template <typename DeallocateOp>
+::flatbuffers::Offset<::tt::target::ttnn::DeallocateOp>
+createDeallocateOp(FlatbufferObjectCache &cache, DeallocateOp op) {
   auto in =
       cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput()));
-  return ::tt::target::ttnn::CreateDeallocOp(*cache.fbb, in);
+  auto force = op.getForceAttr().getValue();
+  return ::tt::target::ttnn::CreateDeallocateOp(*cache.fbb, in, force);
 }
 
 ::flatbuffers::Offset<::tt::target::ttnn::Operation>
@@ -708,8 +709,8 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
     return createOperation(cache, createMaxPool2dOp(cache, max_pool2dOp),
                            debugString);
   }
-  if (auto deallocOp = dyn_cast<DeallocOp>(op); deallocOp) {
-    return createOperation(cache, createDeallocOp(cache, deallocOp),
+  if (auto deallocateOp = dyn_cast<DeallocateOp>(op); deallocateOp) {
+    return createOperation(cache, createDeallocateOp(cache, deallocateOp),
                            debugString);
   }
   if (auto ceilOp = dyn_cast<CeilOp>(op); ceilOp) {
2 changes: 1 addition & 1 deletion runtime/lib/ttnn/operations/CMakeLists.txt
@@ -11,7 +11,7 @@ set(TTNN_OPS_SRCS
   ${CMAKE_CURRENT_SOURCE_DIR}/data_movement/reshape.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/data_movement/slice.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/data_movement/transpose.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/deletion/dealloc.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/deletion/deallocate.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/eltwise/binary/binary.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/eltwise/binary/binary_composite.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/eltwise/unary/unary.cpp
runtime/lib/ttnn/operations/deletion/{dealloc.cpp → deallocate.cpp}
@@ -1,16 +1,16 @@
 // SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
 //
 // SPDX-License-Identifier: Apache-2.0
-#include "dealloc.h"
+#include "deallocate.h"
 #include "tt/runtime/detail/logger.h"
 #include "tt/runtime/detail/ttnn.h"
 
 namespace tt::runtime::ttnn::operations::deletion {
-void run(const ::tt::target::ttnn::DeallocOp *op, ProgramContext &context) {
+void run(const ::tt::target::ttnn::DeallocateOp *op, ProgramContext &context) {
   ProgramTensorPool &tensorPool = context.getTensorPool();
   ::ttnn::Tensor &tensor = tensorPool.at(op->in()->global_id());
   DEBUG_ASSERT(tensor.is_allocated());
-  tensor.deallocate();
+  ::ttnn::deallocate(tensor, op->force());
 
   // The tensor should be deallocated after the deallocate call.
   // Still this assert may be hit in the future for multidevice/async ttnn
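(A note on semantics, as we understand the upstream TTNN API: ::ttnn::deallocate(tensor, force) releases the tensor's device buffer, with force=true requesting deallocation even if the buffer may still be referenced elsewhere; the op's force attribute defaults to false.)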
runtime/lib/ttnn/operations/deletion/{dealloc.h → deallocate.h}
@@ -2,14 +2,14 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 
-#ifndef TTNN_RUNTIME_DEALLOC_H
-#define TTNN_RUNTIME_DEALLOC_H
+#ifndef TTNN_RUNTIME_DEALLOCATE_H
+#define TTNN_RUNTIME_DEALLOCATE_H
 
 #include "tt/runtime/ttnn/types.h"
 #include "ttmlir/Target/TTNN/program_generated.h"
 
 namespace tt::runtime::ttnn::operations::deletion {
-void run(const ::tt::target::ttnn::DeallocOp *op, ProgramContext &context);
+void run(const ::tt::target::ttnn::DeallocateOp *op, ProgramContext &context);
 } // namespace tt::runtime::ttnn::operations::deletion
 
 #endif
6 changes: 3 additions & 3 deletions runtime/lib/ttnn/program.cpp
@@ -10,7 +10,7 @@
 #include "operations/data_movement/reshape.h"
 #include "operations/data_movement/slice.h"
 #include "operations/data_movement/transpose.h"
-#include "operations/deletion/dealloc.h"
+#include "operations/deletion/deallocate.h"
 #include "operations/eltwise/binary/binary.h"
 #include "operations/eltwise/binary/binary_composite.h"
 #include "operations/eltwise/ternary/ternary.h"
@@ -149,8 +149,8 @@ void ProgramExecutor::runOperation(const ::tt::target::ttnn::Operation *op) {
   case ::tt::target::ttnn::OpType::Conv2dOp: {
     return operations::conv::run(op->type_as_Conv2dOp(), context);
   }
-  case ::tt::target::ttnn::OpType::DeallocOp: {
-    return operations::deletion::run(op->type_as_DeallocOp(), context);
+  case ::tt::target::ttnn::OpType::DeallocateOp: {
+    return operations::deletion::run(op->type_as_DeallocateOp(), context);
   }
   case ::tt::target::ttnn::OpType::MaxPool2dOp: {
     return operations::pool::run(op->type_as_MaxPool2dOp(), context);
8 changes: 7 additions & 1 deletion runtime/tools/python/ttrt/common/run.py
@@ -623,7 +623,7 @@ def generate_subparser(subparsers):
     return run_parser
 
 class TorchInitializer:
-    init_fns = sorted(["randn", "arange", "zeros"])
+    init_fns = sorted(["randn", "arange", "zeros", "ones"])
 
     @staticmethod
     def get_initilizer(name):
@@ -660,3 +660,9 @@ def zeros(shape, dtype):
         import torch
 
         return torch.zeros(shape, dtype=dtype)
+
+    @staticmethod
+    def ones(shape, dtype):
+        import torch
+
+        return torch.ones(shape, dtype=dtype)
@@ -8,27 +8,27 @@ module @"dealloc_test" attributes {} {
     %0 = tensor.empty() : tensor<1x256xf32> loc(#loc8)
     %1 = "ttir.matmul"(%arg0, %arg4, %0) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x784xf32>, tensor<784x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc8)
     // CHECK: %{{.+}} = "ttnn.matmul"([[I1:%.+]], [[I2:%.+]], [[O1:%.+]]) {{.+}} -> tensor<1x256xf32, {{.+}}>
-    // CHECK: "ttnn.dealloc"([[I2]]) : (tensor<784x256xf32, {{.+}}) -> ()
-    // CHECK: "ttnn.dealloc"([[I1]]) : (tensor<1x784xf32, {{.+}}>) -> ()
+    // CHECK: "ttnn.deallocate"([[I2]]) {{.+}} : (tensor<784x256xf32, {{.+}}) -> ()
+    // CHECK: "ttnn.deallocate"([[I1]]) {{.+}} : (tensor<1x784xf32, {{.+}}>) -> ()
     %2 = tensor.empty() : tensor<1x256xf32> loc(#loc9)
     %3 = "ttir.add"(%1, %arg3, %2) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x256xf32>, tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc9)
     // CHECK: %{{.+}} = "ttnn.add"([[I1:%.+]], [[I2:%.+]], [[O2:%.+]]) {{.+}} -> tensor<1x256xf32, {{.+}}>
-    // CHECK: "ttnn.dealloc"([[I2]]) : (tensor<1x256xf32, {{.+}}>) -> ()
-    // CHECK: "ttnn.dealloc"([[O1]]) : (tensor<1x256xf32, {{.+}}>) -> ()
+    // CHECK: "ttnn.deallocate"([[I2]]) {{.+}} : (tensor<1x256xf32, {{.+}}>) -> ()
+    // CHECK: "ttnn.deallocate"([[O1]]) {{.+}} : (tensor<1x256xf32, {{.+}}>) -> ()
     %4 = tensor.empty() : tensor<1x256xf32> loc(#loc10)
     %5 = "ttir.relu"(%3, %4) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc10)
     // CHECK: %{{.+}} = "ttnn.relu"([[I1:%.+]], [[O3:%.+]]) {{.+}} -> tensor<1x256xf32, {{.+}}>
-    // CHECK: "ttnn.dealloc"([[O2]]) : (tensor<1x256xf32, {{.+}}>) -> ()
+    // CHECK: "ttnn.deallocate"([[O2]]) {{.+}} : (tensor<1x256xf32, {{.+}}>) -> ()
     %6 = tensor.empty() : tensor<1x10xf32> loc(#loc11)
     %7 = "ttir.matmul"(%5, %arg2, %6) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x256xf32>, tensor<256x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc11)
     // CHECK: %{{.+}} = "ttnn.matmul"([[I1:%.+]], [[I2:%.+]], [[O4:%.+]]) {{.+}} -> tensor<1x10xf32, {{.+}}>
-    // CHECK: "ttnn.dealloc"([[I2]]) : (tensor<256x10xf32, {{.+}}>) -> ()
-    // CHECK: "ttnn.dealloc"([[O3]]) : (tensor<1x256xf32,{{.+}}>) -> ()
+    // CHECK: "ttnn.deallocate"([[I2]]) {{.+}} : (tensor<256x10xf32, {{.+}}>) -> ()
+    // CHECK: "ttnn.deallocate"([[O3]]) {{.+}} : (tensor<1x256xf32,{{.+}}>) -> ()
     %8 = tensor.empty() : tensor<1x10xf32> loc(#loc12)
     %9 = "ttir.add"(%7, %arg1, %8) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc12)
     // CHECK: %{{.+}} = "ttnn.add"([[I1:%.+]], [[I2:%.+]], [[O5:%.+]]) {{.+}} -> tensor<1x10xf32,{{.+}}>
-    // CHECK: "ttnn.dealloc"([[I2]]) : (tensor<1x10xf32, {{.+}}>) -> ()
-    // CHECK: "ttnn.dealloc"([[O4]]) : (tensor<1x10xf32, {{.+}}>) -> ()
+    // CHECK: "ttnn.deallocate"([[I2]]) {{.+}} : (tensor<1x10xf32, {{.+}}>) -> ()
+    // CHECK: "ttnn.deallocate"([[O4]]) {{.+}} : (tensor<1x10xf32, {{.+}}>) -> ()
     %10 = tensor.empty() : tensor<1x10xf32> loc(#loc13)
     %11 = "ttir.softmax"(%9, %10) <{dimension = 1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc13)
     return %11 : tensor<1x10xf32> loc(#loc7)
10 changes: 10 additions & 0 deletions test/ttmlir/Silicon/TTNN/emitc/simple_add.mlir
@@ -0,0 +1,10 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
+// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
+
+#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
+
+func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> {
+  %0 = tensor.empty() : tensor<32x32xbf16>
+  %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
+  return %1 : tensor<32x32xbf16>
+}
7 changes: 5 additions & 2 deletions tools/ttnn-standalone/CMakeLists.txt
@@ -54,14 +54,17 @@ add_executable(ttnn-standalone ttnn-standalone.cpp)
 set_property(TARGET ttnn-standalone PROPERTY CXX_STANDARD 20)
 
 target_include_directories(ttnn-standalone PRIVATE
-    # reflect
-    # TODO: Remove this when ttmetal removes this dependency (reflect) from public facing headers
+    # TODO: Remove these when ttmetal removes the dependencies from public facing headers
     $ENV{TT_METAL_HOME}/.cpmcache/reflect/e75434c4c5f669e4a74e4d84e0a30d7249c1e66f
+    $ENV{TT_METAL_HOME}/.cpmcache/fmt/73b5ec45edbd92babfd91c3777a9e1ab9cac8238/include
+    $ENV{TT_METAL_HOME}/.cpmcache/magic_enum/1e1af177d4ab0ef660f105434fd1017c4d1f8c17/include/magic_enum
+    $ENV{TT_METAL_HOME}/.cpmcache/boost_core/e679bef5c160cf29d0f37d549881dc5f5a58c332/include
 
+    # Metalium
     $ENV{TT_METAL_HOME}
     $ENV{TT_METAL_HOME}/tt_metal
     $ENV{TT_METAL_HOME}/tt_metal/third_party/umd
     $ENV{TT_METAL_HOME}/tt_metal/third_party/umd/device
     $ENV{TT_METAL_HOME}/tt_metal/third_party/fmt
     $ENV{TT_METAL_HOME}/tt_metal/hw/inc
     $ENV{TT_METAL_HOME}/tt_metal/hw/inc/${ARCH_NAME}
(diff for 1 more changed file not loaded)
