#331: Prohibit runtime mixing, strictly check runtime associated with data structures (#369)

jnie-TT authored Aug 12, 2024
1 parent 6358122 commit 4430394
Showing 4 changed files with 76 additions and 41 deletions.
56 changes: 42 additions & 14 deletions runtime/include/tt/runtime/types.h
@@ -5,6 +5,7 @@
#ifndef TT_RUNTIME_TYPES_H
#define TT_RUNTIME_TYPES_H

#include <cassert>
#include <memory>
#include <string_view>
#include <vector>
@@ -15,8 +16,15 @@

namespace tt::runtime {

enum class DeviceRuntime {
Disabled,
TTNN,
TTMetal,
};

namespace detail {
struct ObjectImpl {

std::shared_ptr<void> handle;

ObjectImpl(std::shared_ptr<void> handle) : handle(handle) {}
@@ -25,14 +33,33 @@ struct ObjectImpl {
return *static_cast<T const *>(handle.get());
}
};
} // namespace detail

enum class DeviceRuntime {
Disabled,
TTNN,
TTMetal,
struct RuntimeCheckedObjectImpl {
std::shared_ptr<void> handle;
::tt::runtime::DeviceRuntime associatedRuntime;

RuntimeCheckedObjectImpl(std::shared_ptr<void> handle,
::tt::runtime::DeviceRuntime runtime)
: handle(handle), associatedRuntime(runtime) {}

bool matchesRuntime(DeviceRuntime runtime) const {
return associatedRuntime == runtime;
}

template <typename T> T &as(DeviceRuntime expectedRuntime) {
assert(associatedRuntime == expectedRuntime &&
"Associated runtime does not match expected runtime of cast");
return *static_cast<T *>(handle.get());
}
template <typename T> T const &as(DeviceRuntime expectedRuntime) const {
assert(associatedRuntime == expectedRuntime &&
"Associated runtime does not match expected runtime of cast");
return *static_cast<T const *>(handle.get());
}
};

} // namespace detail

struct TensorDesc {
std::vector<std::uint32_t> shape;
std::vector<std::uint32_t> stride;
@@ -75,22 +102,23 @@ struct Binary : public Flatbuffer {
std::vector<TensorDesc> getProgramOutputs(std::uint32_t programIndex) const;
};

struct Device : public detail::ObjectImpl {
using detail::ObjectImpl::ObjectImpl;
struct Device : public detail::RuntimeCheckedObjectImpl {
using detail::RuntimeCheckedObjectImpl::RuntimeCheckedObjectImpl;

template <typename T> static Device borrow(T &object) {
return Device(utils::unsafe_borrow_shared(&object));
template <typename T> static Device borrow(T &object, DeviceRuntime runtime) {
return Device(utils::unsafe_borrow_shared(&object), runtime);
}
};

struct Event : public detail::ObjectImpl {
using detail::ObjectImpl::ObjectImpl;
struct Event : public detail::RuntimeCheckedObjectImpl {
using detail::RuntimeCheckedObjectImpl::RuntimeCheckedObjectImpl;
};

struct Tensor : public detail::ObjectImpl {
struct Tensor : public detail::RuntimeCheckedObjectImpl {
std::shared_ptr<void> data;
Tensor(std::shared_ptr<void> handle, std::shared_ptr<void> data)
: detail::ObjectImpl(handle), data(data) {}
Tensor(std::shared_ptr<void> handle, std::shared_ptr<void> data,
DeviceRuntime runtime)
: detail::RuntimeCheckedObjectImpl(handle, runtime), data(data) {}
};

} // namespace tt::runtime
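For reference, the checked-cast pattern added above can be exercised on its own. The following is a minimal, self-contained sketch (the enum and struct are re-declared locally rather than included from tt/runtime/types.h, and main() is illustrative only), showing that as<T>() with a mismatched DeviceRuntime fires the assert, while a matching runtime behaves like the old ObjectImpl::as<T>():

#include <cassert>
#include <iostream>
#include <memory>

// Local re-declaration of the pattern from runtime/include/tt/runtime/types.h,
// trimmed down for illustration.
enum class DeviceRuntime { Disabled, TTNN, TTMetal };

struct RuntimeCheckedObjectImpl {
  std::shared_ptr<void> handle;
  DeviceRuntime associatedRuntime;

  RuntimeCheckedObjectImpl(std::shared_ptr<void> handle, DeviceRuntime runtime)
      : handle(handle), associatedRuntime(runtime) {}

  bool matchesRuntime(DeviceRuntime runtime) const {
    return associatedRuntime == runtime;
  }

  template <typename T> T &as(DeviceRuntime expectedRuntime) {
    assert(associatedRuntime == expectedRuntime &&
           "Associated runtime does not match expected runtime of cast");
    return *static_cast<T *>(handle.get());
  }
};

int main() {
  RuntimeCheckedObjectImpl obj(std::make_shared<int>(42), DeviceRuntime::TTNN);

  // Matching runtime: behaves like the old unchecked cast.
  std::cout << obj.as<int>(DeviceRuntime::TTNN) << "\n"; // prints 42

  // Runtime mixing: uncommenting this line aborts in a debug build.
  // obj.as<int>(DeviceRuntime::TTMetal);
  return 0;
}
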
14 changes: 7 additions & 7 deletions runtime/lib/runtime.cpp
@@ -19,23 +19,23 @@ namespace tt::runtime {

namespace detail {
#if defined(TT_RUNTIME_ENABLE_TTNN)
DeviceRuntime currentRuntime = DeviceRuntime::TTNN;
DeviceRuntime globalCurrentRuntime = DeviceRuntime::TTNN;
#elif defined(TT_RUNTIME_ENABLE_TTMETAL)
DeviceRuntime currentRuntime = DeviceRuntime::TTMetal;
DeviceRuntime globalCurrentRuntime = DeviceRuntime::TTMetal;
#else
DeviceRuntime currentRuntime = DeviceRuntime::Disabled;
DeviceRuntime globalCurrentRuntime = DeviceRuntime::Disabled;
#endif

} // namespace detail

DeviceRuntime getCurrentRuntime() {
#if !defined(TT_RUNTIME_ENABLE_TTNN)
assert(detail::currentRuntime != DeviceRuntime::TTNN);
assert(detail::globalCurrentRuntime != DeviceRuntime::TTNN);
#endif
#if !defined(TT_RUNTIME_ENABLE_TTMETAL)
assert(detail::currentRuntime != DeviceRuntime::TTMetal);
assert(detail::globalCurrentRuntime != DeviceRuntime::TTMetal);
#endif
return detail::currentRuntime;
return detail::globalCurrentRuntime;
}

std::vector<DeviceRuntime> getAvailableRuntimes() {
@@ -56,7 +56,7 @@ void setCurrentRuntime(const DeviceRuntime &runtime) {
#if !defined(TT_RUNTIME_ENABLE_TTMETAL)
assert(runtime != DeviceRuntime::TTMetal);
#endif
detail::currentRuntime = runtime;
detail::globalCurrentRuntime = runtime;
}

void setCompatibleRuntime(const Binary &binary) {
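As a usage note, the renamed detail::globalCurrentRuntime is only ever touched through the accessors above. A hedged sketch of a caller that selects whichever backend was compiled in (the tt/runtime/runtime.h include path is an assumption; it is not shown in this diff):

#include "tt/runtime/runtime.h" // assumed header declaring the accessors used below

#include <iostream>

int main() {
  using ::tt::runtime::DeviceRuntime;

  // Pick the first compiled-in backend; if neither TT_RUNTIME_ENABLE_TTNN nor
  // TT_RUNTIME_ENABLE_TTMETAL is defined, the list is empty and the current
  // runtime stays DeviceRuntime::Disabled.
  for (DeviceRuntime runtime : ::tt::runtime::getAvailableRuntimes()) {
    ::tt::runtime::setCurrentRuntime(runtime);
    break;
  }

  std::cout << "runtime selected: "
            << (::tt::runtime::getCurrentRuntime() != DeviceRuntime::Disabled)
            << "\n";
  return 0;
}
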
30 changes: 16 additions & 14 deletions runtime/lib/ttmetal/runtime.cpp
@@ -12,7 +12,7 @@
#include "ttmlir/Version.h"

namespace tt::runtime::ttmetal {

using ::tt::runtime::DeviceRuntime;
constexpr inline std::size_t kHostBufferCommandQueueId = 0;
using Events = std::vector<std::shared_ptr<::tt::tt_metal::Event>>;
using DeviceMesh = std::vector<::tt::tt_metal::Device *>;
@@ -39,7 +39,8 @@ Tensor createTensor(std::shared_ptr<void> data,
desc.itemsize = itemsize;
desc.dataType = dataType;
std::shared_ptr<MetalTensor> tensor = std::make_shared<MetalTensor>(desc);
return Tensor(static_pointer_cast<void>(tensor), data);
return Tensor(static_pointer_cast<void>(tensor), data,
DeviceRuntime::TTMetal);
}

Device openDevice(std::vector<int> const &deviceIds,
@@ -52,11 +53,11 @@ Device openDevice(std::vector<int> const &deviceIds,
deviceMesh->push_back(CreateDevice(deviceId, num_hw_cqs));
++i;
}
return static_pointer_cast<void>(deviceMesh);
return Device(static_pointer_cast<void>(deviceMesh), DeviceRuntime::TTMetal);
}

void closeDevice(Device device) {
DeviceMesh &deviceMesh = device.as<DeviceMesh>();
DeviceMesh &deviceMesh = device.as<DeviceMesh>(DeviceRuntime::TTMetal);
for (::tt::tt_metal::Device *device : deviceMesh) {
::tt::tt_metal::CloseDevice(device);
}
@@ -111,8 +112,8 @@ Events maybeCopyHostOutputs(::tt::tt_metal::Device *device,
Events copyEvents;
int i = 0;
for (Tensor const &outputHandle : outputHandles) {
if (TensorDesc const *hostTensor =
std::get_if<TensorDesc>(&outputHandle.as<MetalTensor>());
if (TensorDesc const *hostTensor = std::get_if<TensorDesc>(
&outputHandle.as<MetalTensor>(DeviceRuntime::TTMetal));
hostTensor) {
::tt::tt_metal::CommandQueue &cq =
device->command_queue(kHostBufferCommandQueueId);
@@ -140,7 +141,7 @@ Event submit(Device deviceHandle, Binary executableHandle,
::tt::target::metal::TTMetalBinary const &fbb = *getBinary(executableHandle);
::tt::target::metal::Program const *program =
fbb.programs()->Get(programIndex);
DeviceMesh &deviceMesh = deviceHandle.as<DeviceMesh>();
DeviceMesh &deviceMesh = deviceHandle.as<DeviceMesh>(DeviceRuntime::TTMetal);
assert(deviceMesh.size() == 1 && "Only one device is supported for now");
std::shared_ptr<Events> events = std::make_shared<Events>();
assert(program->device_programs()->size() == deviceMesh.size() &&
@@ -158,9 +159,9 @@
for (unsigned i = 0; i < inputHandles.size(); ++i) {
::tt::target::TensorRef const *tensorRef =
deviceProgram->inputs()->Get(i);
auto [buffer, event] =
prepareInput(device, inputHandles[i].as<MetalTensor>(),
inputHandles[i].data.get(), tensorRef);
auto [buffer, event] = prepareInput(
device, inputHandles[i].as<MetalTensor>(DeviceRuntime::TTMetal),
inputHandles[i].data.get(), tensorRef);
inputs.emplace_back(deviceProgram->inputs()->Get(i)->global_id(), buffer,
event);
}
@@ -172,8 +173,9 @@
for (unsigned i = 0; i < outputHandles.size(); ++i) {
::tt::target::TensorRef const *tensorRef =
deviceProgram->outputs()->Get(i);
std::shared_ptr<::tt::tt_metal::Buffer> buffer =
prepareOutput(device, &outputHandles[i].as<MetalTensor>(), tensorRef);
std::shared_ptr<::tt::tt_metal::Buffer> buffer = prepareOutput(
device, &outputHandles[i].as<MetalTensor>(DeviceRuntime::TTMetal),
tensorRef);
outputs.emplace_back(deviceProgram->outputs()->Get(i)->global_id(),
buffer);
}
Expand All @@ -195,11 +197,11 @@ Event submit(Device deviceHandle, Binary executableHandle,
events->insert(events->end(), deviceEvents.begin(), deviceEvents.end());
}

return static_pointer_cast<void>(events);
return Event(static_pointer_cast<void>(events), DeviceRuntime::TTMetal);
}

void wait(Event event) {
Events events = event.as<Events>();
Events events = event.as<Events>(DeviceRuntime::TTMetal);
for (auto e : events) {
::tt::tt_metal::EventSynchronize(e);
}
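The caller-visible effect of these changes is that every Device, Event, and Tensor handed out by this backend is now stamped with DeviceRuntime::TTMetal. A hedged sketch of a dispatch helper built on that tag (the tt/runtime/ttmetal/runtime.h include path is an assumption; closeDevice and matchesRuntime are used as shown in this diff):

#include "tt/runtime/ttmetal/runtime.h" // assumed include path for the backend API

using ::tt::runtime::Device;
using ::tt::runtime::DeviceRuntime;

void closeIfMetal(Device device) {
  // A caller can now check the tag up front instead of relying on the assert
  // inside closeDevice's device.as<DeviceMesh>(DeviceRuntime::TTMetal).
  if (device.matchesRuntime(DeviceRuntime::TTMetal)) {
    ::tt::runtime::ttmetal::closeDevice(device);
  }
}
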
17 changes: 11 additions & 6 deletions runtime/lib/ttnn/runtime.cpp
@@ -10,6 +10,8 @@

namespace tt::runtime::ttnn {

using ::tt::runtime::DeviceRuntime;

template <typename T>
static BorrowedStorage createStorage(void *ptr, std::uint32_t numElements) {
return BorrowedStorage(
@@ -45,19 +47,19 @@ Tensor createTensor(std::shared_ptr<void> data,
auto tensor = std::make_shared<::ttnn::Tensor>(
createStorage(data.get(), numElements, dataType), shape,
utils::toTTNNDataType(dataType), ::ttnn::Layout::ROW_MAJOR);
return Tensor(tensor, data);
return Tensor(tensor, data, DeviceRuntime::TTNN);
}

Device openDevice(std::vector<int> const &deviceIds,
std::vector<std::uint8_t> const &numHWCQs) {
assert(deviceIds.size() == 1 && "Only one device is supported for now");
assert(numHWCQs.empty() && "HWCQs are not supported for now");
auto &device = ::ttnn::open_device(deviceIds.front());
return Device::borrow(device);
return Device::borrow(device, DeviceRuntime::TTNN);
}

void closeDevice(Device device) {
auto &ttnn_device = device.as<::ttnn::Device>();
auto &ttnn_device = device.as<::ttnn::Device>(DeviceRuntime::TTNN);
::ttnn::close_device(ttnn_device);
}

@@ -74,25 +76,28 @@ Event submit(Device deviceHandle, Binary executableHandle,
std::uint32_t programIndex,
std::vector<Tensor> const &inputHandles,
std::vector<Tensor> const &outputHandles) {
::ttnn::Device &device = deviceHandle.as<::ttnn::Device>();
::ttnn::Device &device = deviceHandle.as<::ttnn::Device>(DeviceRuntime::TTNN);
::tt::target::ttnn::TTNNBinary const &fbb = *getBinary(executableHandle);
std::vector<::ttnn::Tensor *> inputs;
inputs.reserve(inputHandles.size());
for (auto &input : inputHandles) {
assert(input.matchesRuntime(DeviceRuntime::TTNN));
inputs.push_back(static_cast<::ttnn::Tensor *>(input.handle.get()));
}
std::vector<::ttnn::Tensor *> outputs;
outputs.reserve(outputHandles.size());
for (auto &output : outputHandles) {
assert(output.matchesRuntime(DeviceRuntime::TTNN));
outputs.push_back(static_cast<::ttnn::Tensor *>(output.handle.get()));
}
tt::runtime::ttnn::runProgram(device, fbb.programs()->Get(programIndex),
inputs, outputs);
return Event(nullptr);
return Event(nullptr, DeviceRuntime::TTNN);
}

void wait(Event) {
void wait(Event event) {
// Not implemented
assert(event.matchesRuntime(DeviceRuntime::TTNN));
}

} // namespace tt::runtime::ttnn
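
Taken together with the TTMetal changes, the check this commit is named for now sits at the submit boundary: every input and output handle is verified against the backend it enters. A hedged sketch of the mixing error that is now caught (the include path is an assumption, the empty output list is for illustration only, and the sketch relies only on the submit signature shown in this diff):

#include "tt/runtime/ttnn/runtime.h" // assumed include path for the backend API

using ::tt::runtime::Binary;
using ::tt::runtime::Device;
using ::tt::runtime::Tensor;

void submitMixedExample(Device ttnnDevice, Binary flatbuffer,
                        Tensor metalTensor /* produced by the TTMetal backend */) {
  // metalTensor carries DeviceRuntime::TTMetal, so in a debug build this call
  // now stops at assert(input.matchesRuntime(DeviceRuntime::TTNN)) instead of
  // silently reinterpreting the handle as a ::ttnn::Tensor.
  ::tt::runtime::ttnn::submit(ttnnDevice, flatbuffer, /*programIndex=*/0,
                              /*inputHandles=*/{metalTensor},
                              /*outputHandles=*/{});
}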
