Skip to content

Commit

Permalink
#299: Support compiling both runtimes and toggling runtime type.
Browse files Browse the repository at this point in the history
  • Loading branch information
jnie-TT committed Aug 6, 2024
1 parent 2bdb825 commit 18405ee
Show file tree
Hide file tree
Showing 9 changed files with 110 additions and 49 deletions.
5 changes: 0 additions & 5 deletions runtime/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,11 @@ option(TTMLIR_ENABLE_RUNTIME_TESTS "Enable runtime tests" OFF)
option(TT_RUNTIME_ENABLE_TTNN "Enable TTNN Runtime" OFF)
option(TT_RUNTIME_ENABLE_TTMETAL "Enable TTMetal Runtime" OFF)

if (TT_RUNTIME_ENABLE_TTNN AND TT_RUNTIME_ENABLE_TTMETAL)
message(FATAL_ERROR "Cannot enable both TTNN and TTMETAL runtimes")
endif()

if (NOT TT_RUNTIME_ENABLE_TTNN AND NOT TT_RUNTIME_ENABLE_TTMETAL)
# Default to TTNN
set(TT_RUNTIME_ENABLE_TTNN ON)
endif()


add_subdirectory(lib)
add_subdirectory(tools)
if (TTMLIR_ENABLE_RUNTIME_TESTS)
Expand Down
14 changes: 14 additions & 0 deletions runtime/include/tt/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,20 @@

namespace tt::runtime {

namespace detail {
static DeviceRuntime currentRuntime = DeviceRuntime::TTNN;
} // namespace detail

inline const DeviceRuntime &getCurrentRuntime() {
return detail::currentRuntime;
}

inline void setCurrentRuntime(const DeviceRuntime &runtime) {
detail::currentRuntime = runtime;
}

void setCompatibleRuntime(const Binary &binary);

std::pair<SystemDesc, DeviceIds> getCurrentSystemDesc();

Tensor createTensor(std::shared_ptr<void> data,
Expand Down
5 changes: 5 additions & 0 deletions runtime/include/tt/runtime/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ struct ObjectImpl {
};
} // namespace detail

enum class DeviceRuntime {
TTNN,
TTMetal,
};

struct TensorDesc {
std::vector<std::uint32_t> shape;
std::vector<std::uint32_t> stride;
Expand Down
107 changes: 71 additions & 36 deletions runtime/lib/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,40 @@
#include "tt/runtime/utils.h"
#include "ttmlir/Version.h"

#if defined(TT_RUNTIME_ENABLE_TTNN) && defined(TT_RUNTIME_ENABLE_TTMETAL)
#error \
"Only one of TT_RUNTIME_ENABLE_TTNN and TT_RUNTIME_ENABLE_TTMETAL can be defined"
#endif

#if defined(TT_RUNTIME_ENABLE_TTNN)
#include "tt/runtime/detail/ttnn.h"
#elif defined(TT_RUNTIME_ENABLE_TTMETAL)
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
#include "tt/runtime/detail/ttmetal.h"
#endif

namespace tt::runtime {

void setCompatibleRuntime(const Binary &binary) {
std::string_view fileIdentifier = binary.getFileIdentifier();
if (fileIdentifier == "TTNN") {
setCurrentRuntime(DeviceRuntime::TTNN);
} else if (fileIdentifier == "TTM0") {
setCurrentRuntime(DeviceRuntime::TTMetal);
} else {
throw std::runtime_error("Unsupported runtime binary file identifier");
}
}

std::pair<SystemDesc, DeviceIds> getCurrentSystemDesc() {
#if defined(TT_RUNTIME_ENABLE_TTNN)
return ::tt::runtime::ttnn::getCurrentSystemDesc();
#elif defined(TT_RUNTIME_ENABLE_TTMETAL)
return ::tt::runtime::ttmetal::getCurrentSystemDesc();
#else
throw std::runtime_error("runtime is not enabled");
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::getCurrentSystemDesc();
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::getCurrentSystemDesc();
}
#endif
throw std::runtime_error("runtime is not enabled");
}

Tensor createTensor(std::shared_ptr<void> data,
Expand All @@ -36,51 +50,72 @@ Tensor createTensor(std::shared_ptr<void> data,
assert(not stride.empty());
assert(itemsize > 0);
#if defined(TT_RUNTIME_ENABLE_TTNN)
return ::tt::runtime::ttnn::createTensor(data, shape, stride, itemsize,
dataType);
#elif defined(TT_RUNTIME_ENABLE_TTMETAL)
return ::tt::runtime::ttmetal::createTensor(data, shape, stride, itemsize,
dataType);
#else
throw std::runtime_error("runtime is not enabled");
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::createTensor(data, shape, stride, itemsize,
dataType);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::createTensor(data, shape, stride, itemsize,
dataType);
}
#endif
throw std::runtime_error("runtime is not enabled");
}

Device openDevice(std::vector<int> const &deviceIds,
std::vector<std::uint8_t> const &numHWCQs) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
return ::tt::runtime::ttnn::openDevice(deviceIds, numHWCQs);
#elif defined(TT_RUNTIME_ENABLE_TTMETAL)
return ::tt::runtime::ttmetal::openDevice(deviceIds, numHWCQs);
#else
throw std::runtime_error("runtime is not enabled");
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::openDevice(deviceIds, numHWCQs);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::openDevice(deviceIds, numHWCQs);
}
#endif
throw std::runtime_error("runtime is not enabled");
}

void closeDevice(Device device) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
return ::tt::runtime::ttnn::closeDevice(device);
#elif defined(TT_RUNTIME_ENABLE_TTMETAL)
return ::tt::runtime::ttmetal::closeDevice(device);
#else
throw std::runtime_error("runtime is not enabled");
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::closeDevice(device);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::closeDevice(device);
}
#endif
throw std::runtime_error("runtime is not enabled");
}

Event submit(Device deviceHandle, Binary executableHandle,
std::uint32_t programIndex,
std::vector<Tensor> const &inputHandles,
std::vector<Tensor> const &outputHandles) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
return ::tt::runtime::ttnn::submit(deviceHandle, executableHandle,
programIndex, inputHandles, outputHandles);
#elif defined(TT_RUNTIME_ENABLE_TTMETAL)
return ::tt::runtime::ttmetal::submit(deviceHandle, executableHandle,
programIndex, inputHandles,
outputHandles);
#else
throw std::runtime_error("runtime is not enabled");
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::submit(deviceHandle, executableHandle,
programIndex, inputHandles,
outputHandles);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::submit(deviceHandle, executableHandle,
programIndex, inputHandles,
outputHandles);
}
#endif
throw std::runtime_error("runtime is not enabled");
}

void wait(Event) { throw std::runtime_error("Not implemented"); }
Expand Down
1 change: 1 addition & 0 deletions runtime/test/ttnn/test_subtract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ TEST(TTNNSubtract, Equal) {
assert(fbPath && "Path to subtract flatbuffer must be provided");
::tt::runtime::Binary fbb = ::tt::runtime::Binary::loadFromPath(fbPath);
EXPECT_EQ(fbb.getFileIdentifier(), "TTNN");
::tt::runtime::setCompatibleRuntime(fbb);
std::vector<::tt::runtime::TensorDesc> inputDescs = fbb.getProgramInputs(0);
std::vector<::tt::runtime::TensorDesc> outputDescs = fbb.getProgramOutputs(0);
std::vector<::tt::runtime::Tensor> inputTensors, outputTensors;
Expand Down
14 changes: 7 additions & 7 deletions runtime/tools/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@
]

dylibs = []
linklibs = []
linklibs = ["TTBinary"]
if enable_ttnn:
dylibs = ["_ttnn.so"]
linklibs = ["TTRuntimeTTNN", ":_ttnn.so"]
elif enable_ttmetal:
assert enable_ttmetal
dylibs = ["libtt_metal.so"]
linklibs = ["TTRuntimeTTMetal", "tt_metal"]
dylibs += ["_ttnn.so"]
linklibs += ["TTRuntimeTTNN", ":_ttnn.so"]

if enable_ttmetal:
dylibs += ["libtt_metal.so"]
linklibs += ["TTRuntimeTTMetal", "tt_metal"]

if enable_runtime:
assert enable_ttmetal or enable_ttnn, "At least one runtime must be enabled"
Expand Down
2 changes: 1 addition & 1 deletion runtime/tools/python/ttrt/common/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,9 @@ def run(args):
torch.manual_seed(args.seed)

for (binary_name, fbb, fbb_dict, program_indices) in fbb_list:
ttrt.runtime.set_compatible_runtime(fbb)
torch_inputs[binary_name] = {}
torch_outputs[binary_name] = {}

for program_index in program_indices:
torch_inputs[binary_name][program_index] = []
torch_outputs[binary_name][program_index] = []
Expand Down
3 changes: 3 additions & 0 deletions runtime/tools/python/ttrt/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
Event,
Tensor,
DataType,
DeviceRuntime,
get_current_runtime,
set_compatible_runtime,
get_current_system_desc,
open_device,
close_device,
Expand Down
8 changes: 8 additions & 0 deletions runtime/tools/python/ttrt/runtime/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,15 @@ PYBIND11_MODULE(_C, m) {
.value("UInt32", ::tt::target::DataType::UInt32)
.value("UInt16", ::tt::target::DataType::UInt16)
.value("UInt8", ::tt::target::DataType::UInt8);
py::enum_<::tt::runtime::DeviceRuntime>(m, "DeviceRuntime")
.value("TTNN", ::tt::runtime::DeviceRuntime::TTNN)
.value("TTMetal", ::tt::runtime::DeviceRuntime::TTMetal);

m.def("get_current_runtime", &tt::runtime::getCurrentRuntime,
"Get the backend device runtime type");
m.def("set_compatible_runtime", &tt::runtime::setCompatibleRuntime,
py::arg("binary"),
"Set the backend device runtime type to match the binary");
m.def("get_current_system_desc", &tt::runtime::getCurrentSystemDesc,
"Get the current system descriptor");
m.def(
Expand Down

0 comments on commit 18405ee

Please sign in to comment.