Skip to content

Commit

Permalink
Added TTNN Bindings
Browse files Browse the repository at this point in the history
updated imports

Added TTNN Ops, Attrs, & Enums to Python Bindings
  • Loading branch information
vprajapati-tt committed Oct 21, 2024
1 parent 43c63a5 commit cc5f60e
Show file tree
Hide file tree
Showing 13 changed files with 268 additions and 1 deletion.
1 change: 1 addition & 0 deletions include/ttmlir-c/Dialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ extern "C" {
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(TT, tt);
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(TTIR, ttir);
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(TTKernel, ttkernel);
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(TTNN, ttnn);

#ifdef __cplusplus
}
Expand Down
48 changes: 48 additions & 0 deletions include/ttmlir-c/TTNNAttrs.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_C_TTNNATTRS_H
#define TTMLIR_C_TTNNATTRS_H

#include "ttmlir-c/Dialects.h"

#ifdef __cplusplus
extern "C" {
#endif

MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTNNCoreRangeAttrGet(MlirContext ctx,
int64_t *offset,
size_t offsetSize,
int64_t *size,
size_t sizeSize);

MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTNNCoreRangeArrayAttrGet(
MlirContext ctx, MlirAttribute *coreRangeAttrs, size_t coreRangeAttrsSize);

MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTNNLayoutAttrGet(MlirContext ctx,
uint32_t layout);

MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTNNTensorMemoryLayoutAttrGet(
MlirContext ctx, uint32_t tensorMemoryLayout);

MLIR_CAPI_EXPORTED MlirAttribute
ttmlirTTNNBufferTypeAttrGet(MlirContext ctx, uint32_t bufferType);

MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTNNMemoryConfigAttrGet(
MlirContext ctx, MlirAttribute tensorMemoryLayoutAttr,
MlirAttribute bufferTypeAttr);

MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTNNShapeAttrGet(MlirContext ctx,
int64_t *shape,
size_t shapeSize);

MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTNNMeshShapeAttrGet(MlirContext ctx,
int64_t y,
int64_t x);

#ifdef __cplusplus
}
#endif

#endif // TTMLIR_C_TTNNATTRS_H
1 change: 1 addition & 0 deletions include/ttmlir/Bindings/Python/TTMLIRModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ namespace py = pybind11;
namespace mlir::ttmlir::python {
void populateTTModule(py::module &m);
void populateTTKernelModule(py::module &m);
void populateTTNNModule(py::module &m);
void populateOverridesModule(py::module &m);
void populatePassesModule(py::module &m);
} // namespace mlir::ttmlir::python
Expand Down
2 changes: 2 additions & 0 deletions lib/CAPI/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_mlir_public_c_api_library(TTMLIRCAPI
TTKernelTypes.cpp
TTAttrs.cpp
TTTypes.cpp
TTNNAttrs.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/ttmlir-c/
Expand All @@ -16,6 +17,7 @@ add_mlir_public_c_api_library(TTMLIRCAPI
MLIRTTDialect
MLIRTTIRDialect
MLIRTTKernelDialect
MLIRTTNNDialect
MLIRTTIRTransforms
MLIRTTIRAnalysis
)
2 changes: 2 additions & 0 deletions lib/CAPI/Dialects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Dialect/TTIR/IR/TTIR.h"
#include "ttmlir/Dialect/TTKernel/IR/TTKernel.h"
#include "ttmlir/Dialect/TTNN/IR/TTNN.h"

MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(TT, tt, mlir::tt::TTDialect)
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(TTIR, ttir, mlir::tt::ttir::TTIRDialect)
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(TTKernel, ttkernel,
mlir::tt::ttkernel::TTKernelDialect)
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(TTNN, ttnn, mlir::tt::ttnn::TTNNDialect)
66 changes: 66 additions & 0 deletions lib/CAPI/TTNNAttrs.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir-c/TTNNAttrs.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"

#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"

namespace mlir::tt::ttnn {

MlirAttribute ttmlirTTNNCoreRangeAttrGet(MlirContext ctx, int64_t *offset,
size_t offsetSize, int64_t *size,
size_t sizeSize) {
return wrap(CoreRangeAttr::get(unwrap(ctx), {offset, offset + offsetSize},
{size, size + sizeSize}));
}

MlirAttribute ttmlirTTNNCoreRangeArrayAttrGet(MlirContext ctx,
MlirAttribute *coreRangeAttrs,
size_t coreRangeAttrsSize) {
std::vector<mlir::Attribute> coreRanges;
for (size_t i = 0; i < coreRangeAttrsSize; i++) {
coreRanges.push_back(mlir::cast<CoreRangeAttr>(unwrap(coreRangeAttrs[i])));
}
return wrap(ArrayAttr::get(unwrap(ctx), coreRanges));
}

MlirAttribute ttmlirTTNNLayoutAttrGet(MlirContext ctx, uint32_t layout) {
return wrap(LayoutAttr::get(unwrap(ctx), static_cast<Layout>(layout)));
}

MlirAttribute ttmlirTTNNTensorMemoryLayoutAttrGet(MlirContext ctx,
uint32_t tensorMemoryLayout) {
return wrap(TensorMemoryLayoutAttr::get(
unwrap(ctx), static_cast<TensorMemoryLayout>(tensorMemoryLayout)));
}

MlirAttribute ttmlirTTNNBufferTypeAttrGet(MlirContext ctx,
uint32_t bufferType) {
return wrap(
BufferTypeAttr::get(unwrap(ctx), static_cast<BufferType>(bufferType)));
}

MlirAttribute
ttmlirTTNNMemoryConfigAttrGet(MlirContext ctx,
MlirAttribute tensorMemoryLayoutAttr,
MlirAttribute bufferTypeAttr) {
return wrap(MemoryConfigAttr::get(
unwrap(ctx),
mlir::cast<TensorMemoryLayoutAttr>(unwrap(tensorMemoryLayoutAttr)),
mlir::cast<BufferTypeAttr>(unwrap(bufferTypeAttr))));
}

MlirAttribute ttmlirTTNNShapeAttrGet(MlirContext ctx, int64_t *shape,
size_t shapeSize) {
return wrap(ShapeAttr::get(unwrap(ctx), {shape, shape + shapeSize}));
}

MlirAttribute ttmlirTTNNMeshShapeAttrGet(MlirContext ctx, int64_t y,
int64_t x) {
return wrap(MeshShapeAttr::get(unwrap(ctx), y, x));
}

} // namespace mlir::tt::ttnn
13 changes: 12 additions & 1 deletion python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ declare_mlir_dialect_python_bindings(
ROOT_DIR "${TTMLIR_PYTHON_ROOT_DIR}"
TD_FILE dialects/TTBinding.td
GEN_ENUM_BINDINGS ON
GEN_ENUM_TD_FILE dialects/TTEnumBindings.td
GEN_ENUM_TD_FILE dialects/TTEnumBinding.td
SOURCES dialects/tt.py
DIALECT_NAME tt
)
Expand All @@ -47,6 +47,16 @@ declare_mlir_dialect_python_bindings(
DIALECT_NAME ttkernel
)

declare_mlir_dialect_python_bindings(
ADD_TO_PARENT TTMLIRPythonSources.Dialects
ROOT_DIR "${TTMLIR_PYTHON_ROOT_DIR}"
TD_FILE dialects/TTNNBinding.td
GEN_ENUM_BINDINGS ON
GEN_ENUM_TD_FILE dialects/TTNNEnumBinding.td
SOURCES dialects/ttnn.py
DIALECT_NAME ttnn
)

declare_mlir_python_sources(TTMLIRPythonSources.Overrides
ROOT_DIR "${TTMLIR_PYTHON_ROOT_DIR}"
ADD_TO_PARENT TTMLIRPythonSources
Expand All @@ -72,6 +82,7 @@ declare_mlir_python_extension(TTMLIRPythonExtensions.Main
TTMLIRModule.cpp
TTModule.cpp
TTKernelModule.cpp
TTNNModule.cpp
Overrides.cpp
Passes.cpp
EMBED_CAPI_LINK_LIBS
Expand Down
5 changes: 5 additions & 0 deletions python/TTMLIRModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@ PYBIND11_MODULE(_ttmlir, m) {
MlirDialectHandle tt_handle = mlirGetDialectHandle__tt__();
MlirDialectHandle ttir_handle = mlirGetDialectHandle__ttir__();
MlirDialectHandle ttkernel_handle = mlirGetDialectHandle__ttkernel__();
MlirDialectHandle ttnn_handle = mlirGetDialectHandle__ttnn__();
mlirDialectHandleRegisterDialect(tt_handle, context);
mlirDialectHandleRegisterDialect(ttir_handle, context);
mlirDialectHandleRegisterDialect(ttkernel_handle, context);
mlirDialectHandleRegisterDialect(ttnn_handle, context);
if (load) {
mlirDialectHandleLoadDialect(tt_handle, context);
mlirDialectHandleLoadDialect(ttir_handle, context);
mlirDialectHandleLoadDialect(ttkernel_handle, context);
mlirDialectHandleLoadDialect(ttnn_handle, context);
}
},
py::arg("context"), py::arg("load") = true);
Expand All @@ -28,6 +31,8 @@ PYBIND11_MODULE(_ttmlir, m) {
mlir::ttmlir::python::populateTTModule(tt_ir);
auto ttkernel_ir = m.def_submodule("ttkernel_ir", "TTKernel IR Bindings");
mlir::ttmlir::python::populateTTKernelModule(ttkernel_ir);
auto ttnn_ir = m.def_submodule("ttnn_ir", "TTNN IR Bindings");
mlir::ttmlir::python::populateTTNNModule(ttnn_ir);
auto overrides = m.def_submodule("overrides", "Python-Bound Overrides");
mlir::ttmlir::python::populateOverridesModule(overrides);
auto passes =
Expand Down
104 changes: 104 additions & 0 deletions python/TTNNModule.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Bindings/Python/TTMLIRModule.h"

namespace mlir::ttmlir::python {
void populateTTNNModule(py::module &m) {

py::class_<tt::ttnn::CoreRangeAttr>(m, "CoreRangeAttr")
.def_static("get",
[](MlirContext ctx, std::vector<int64_t> offset,
std::vector<int64_t> size) {
return wrap(tt::ttnn::CoreRangeAttr::get(unwrap(ctx),
offset, size));
})
.def_static(
"get_with_grid",
[](MlirContext ctx, MlirAttribute grid, std::vector<int64_t> offset) {
llvm::SmallVector<int64_t> offsetVec{0, 0};
if (offset.size() == 2 && not(offset[0] == 0 && offset[1] == 0)) {
offsetVec[0] = offset[0];
offsetVec[1] = offset[1];
}
return wrap(tt::ttnn::CoreRangeAttr::get(
unwrap(ctx), mlir::cast<tt::GridAttr>(unwrap(grid)),
offsetVec));
},
py::arg("ctx"), py::arg("grid"),
py::arg("offset") = std::vector<int64_t>{0, 0});
py::class_<tt::ttnn::LayoutAttr>(m, "LayoutAttr")
.def_static("get",
[](MlirContext ctx, uint32_t layout) {
return wrap(tt::ttnn::LayoutAttr::get(
unwrap(ctx), static_cast<tt::ttnn::Layout>(layout)));
})
.def_property_readonly("value", [](tt::ttnn::LayoutAttr self) {
return static_cast<uint32_t>(self.getValue());
});
py::class_<tt::ttnn::TensorMemoryLayoutAttr>(m, "TensorMemoryLayoutAttr")
.def_static("get",
[](MlirContext ctx, uint32_t tensorMemoryLayout) {
return wrap(tt::ttnn::TensorMemoryLayoutAttr::get(
unwrap(ctx), static_cast<tt::ttnn::TensorMemoryLayout>(
tensorMemoryLayout)));
})
.def_property_readonly("value",
[](tt::ttnn::TensorMemoryLayoutAttr self) {
return static_cast<uint32_t>(self.getValue());
});
py::class_<tt::ttnn::BufferTypeAttr>(m, "BufferTypeAttr")
.def_static(
"get",
[](MlirContext ctx, uint32_t bufferType) {
return wrap(tt::ttnn::BufferTypeAttr::get(
unwrap(ctx), static_cast<tt::ttnn::BufferType>(bufferType)));
})
.def_property_readonly("value", [](tt::ttnn::BufferTypeAttr self) {
return static_cast<uint32_t>(self.getValue());
});
py::class_<tt::ttnn::MemoryConfigAttr>(m, "MemoryConfigAttr")
.def_static("get",
[](MlirContext ctx,
tt::ttnn::TensorMemoryLayoutAttr tensorMemoryLayoutAttr,
tt::ttnn::BufferTypeAttr bufferTypeAttr) {
return wrap(tt::ttnn::MemoryConfigAttr::get(
unwrap(ctx), tensorMemoryLayoutAttr, bufferTypeAttr));
})
.def_static(
"get_by_value",
[](MlirContext ctx, uint32_t tensorMemoryLayout,
uint32_t bufferType) {
return wrap(tt::ttnn::MemoryConfigAttr::get(
unwrap(ctx),
tt::ttnn::TensorMemoryLayoutAttr::get(
unwrap(ctx), static_cast<tt::ttnn::TensorMemoryLayout>(
tensorMemoryLayout)),
tt::ttnn::BufferTypeAttr::get(
unwrap(ctx),
static_cast<tt::ttnn::BufferType>(bufferType))));
})
.def_property_readonly("tensor_memory_layout",
&tt::ttnn::MemoryConfigAttr::getTensorMemoryLayout)
.def_property_readonly("buffer_type",
&tt::ttnn::MemoryConfigAttr::getBufferType);
py::class_<tt::ttnn::ShapeAttr>(m, "ShapeAttr")
.def_static("get",
[](MlirContext ctx, std::vector<int64_t> shape) {
return wrap(tt::ttnn::ShapeAttr::get(unwrap(ctx), shape));
})
.def_property_readonly("shape", [](tt::ttnn::ShapeAttr self) {
return std::vector<int64_t>(self.getShape().begin(),
self.getShape().end());
});
py::class_<tt::ttnn::MeshShapeAttr>(m, "MeshShapeAttr")
.def_static("get",
[](MlirContext ctx, int64_t y, int64_t x) {
return wrap(
tt::ttnn::MeshShapeAttr::get(unwrap(ctx), y, x));
})
.def_property_readonly("y", &tt::ttnn::MeshShapeAttr::getY)
.def_property_readonly("x", &tt::ttnn::MeshShapeAttr::getX);
}
} // namespace mlir::ttmlir::python
File renamed without changes.
10 changes: 10 additions & 0 deletions python/ttmlir/dialects/TTNNBinding.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#ifndef PYTHON_BINDINGS_TTMLIR_TTNNOPS
#define PYTHON_BINDINGS_TTMLIR_TTNNOPS

include "ttmlir/Dialect/TTNN/IR/TTNNOps.td"

#endif
10 changes: 10 additions & 0 deletions python/ttmlir/dialects/TTNNEnumBinding.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#ifndef PYTHON_BINDINGS_TTMLIR_TTNNENUMS
#define PYTHON_BINDINGS_TTMLIR_TTNNENUMS

include "ttmlir/Dialect/TTNN/IR/TTNNOpsEnums.td"

#endif
7 changes: 7 additions & 0 deletions python/ttmlir/dialects/ttnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from ._ttnn_ops_gen import *
from ._ttnn_enum_gen import *
from .._mlir_libs._ttmlir import register_dialect, ttnn_ir as ir

0 comments on commit cc5f60e

Please sign in to comment.