From 4430394fb50174c25da2156048f1ac4e62dce0a6 Mon Sep 17 00:00:00 2001 From: Jackson Nie Date: Mon, 12 Aug 2024 13:49:15 -0400 Subject: [PATCH] #331: Prohibit runtime mixing, strictly check runtime associated with data structures (#369) --- runtime/include/tt/runtime/types.h | 56 ++++++++++++++++++++++-------- runtime/lib/runtime.cpp | 14 ++++---- runtime/lib/ttmetal/runtime.cpp | 30 ++++++++-------- runtime/lib/ttnn/runtime.cpp | 17 +++++---- 4 files changed, 76 insertions(+), 41 deletions(-) diff --git a/runtime/include/tt/runtime/types.h b/runtime/include/tt/runtime/types.h index eca1e6474..bfb7e4ba5 100644 --- a/runtime/include/tt/runtime/types.h +++ b/runtime/include/tt/runtime/types.h @@ -5,6 +5,7 @@ #ifndef TT_RUNTIME_TYPES_H #define TT_RUNTIME_TYPES_H +#include #include #include #include @@ -15,8 +16,15 @@ namespace tt::runtime { +enum class DeviceRuntime { + Disabled, + TTNN, + TTMetal, +}; + namespace detail { struct ObjectImpl { + std::shared_ptr handle; ObjectImpl(std::shared_ptr handle) : handle(handle) {} @@ -25,14 +33,33 @@ struct ObjectImpl { return *static_cast(handle.get()); } }; -} // namespace detail -enum class DeviceRuntime { - Disabled, - TTNN, - TTMetal, +struct RuntimeCheckedObjectImpl { + std::shared_ptr handle; + ::tt::runtime::DeviceRuntime associatedRuntime; + + RuntimeCheckedObjectImpl(std::shared_ptr handle, + ::tt::runtime::DeviceRuntime runtime) + : handle(handle), associatedRuntime(runtime) {} + + bool matchesRuntime(DeviceRuntime runtime) const { + return associatedRuntime == runtime; + } + + template T &as(DeviceRuntime expectedRuntime) { + assert(associatedRuntime == expectedRuntime && + "Associated runtime does not match expected runtime of cast"); + return *static_cast(handle.get()); + } + template T const &as(DeviceRuntime expectedRuntime) const { + assert(associatedRuntime == expectedRuntime && + "Associated runtime does not match expected runtime of cast"); + return *static_cast(handle.get()); + } }; +} // namespace detail + struct TensorDesc { std::vector shape; std::vector stride; @@ -75,22 +102,23 @@ struct Binary : public Flatbuffer { std::vector getProgramOutputs(std::uint32_t programIndex) const; }; -struct Device : public detail::ObjectImpl { - using detail::ObjectImpl::ObjectImpl; +struct Device : public detail::RuntimeCheckedObjectImpl { + using detail::RuntimeCheckedObjectImpl::RuntimeCheckedObjectImpl; - template static Device borrow(T &object) { - return Device(utils::unsafe_borrow_shared(&object)); + template static Device borrow(T &object, DeviceRuntime runtime) { + return Device(utils::unsafe_borrow_shared(&object), runtime); } }; -struct Event : public detail::ObjectImpl { - using detail::ObjectImpl::ObjectImpl; +struct Event : public detail::RuntimeCheckedObjectImpl { + using detail::RuntimeCheckedObjectImpl::RuntimeCheckedObjectImpl; }; -struct Tensor : public detail::ObjectImpl { +struct Tensor : public detail::RuntimeCheckedObjectImpl { std::shared_ptr data; - Tensor(std::shared_ptr handle, std::shared_ptr data) - : detail::ObjectImpl(handle), data(data) {} + Tensor(std::shared_ptr handle, std::shared_ptr data, + DeviceRuntime runtime) + : detail::RuntimeCheckedObjectImpl(handle, runtime), data(data) {} }; } // namespace tt::runtime diff --git a/runtime/lib/runtime.cpp b/runtime/lib/runtime.cpp index 3e605434e..b4450952a 100644 --- a/runtime/lib/runtime.cpp +++ b/runtime/lib/runtime.cpp @@ -19,23 +19,23 @@ namespace tt::runtime { namespace detail { #if defined(TT_RUNTIME_ENABLE_TTNN) -DeviceRuntime currentRuntime = DeviceRuntime::TTNN; +DeviceRuntime globalCurrentRuntime = DeviceRuntime::TTNN; #elif defined(TT_RUNTIME_ENABLE_TTMETAL) -DeviceRuntime currentRuntime = DeviceRuntime::TTMetal; +DeviceRuntime globalCurrentRuntime = DeviceRuntime::TTMetal; #else -DeviceRuntime currentRuntime = DeviceRuntime::Disabled; +DeviceRuntime globalCurrentRuntime = DeviceRuntime::Disabled; #endif } // namespace detail DeviceRuntime getCurrentRuntime() { #if !defined(TT_RUNTIME_ENABLE_TTNN) - assert(detail::currentRuntime != DeviceRuntime::TTNN); + assert(detail::globalCurrentRuntime != DeviceRuntime::TTNN); #endif #if !defined(TT_RUNTIME_ENABLE_TTMETAL) - assert(detail::currentRuntime != DeviceRuntime::TTMetal); + assert(detail::globalCurrentRuntime != DeviceRuntime::TTMetal); #endif - return detail::currentRuntime; + return detail::globalCurrentRuntime; } std::vector getAvailableRuntimes() { @@ -56,7 +56,7 @@ void setCurrentRuntime(const DeviceRuntime &runtime) { #if !defined(TT_RUNTIME_ENABLE_TTMETAL) assert(runtime != DeviceRuntime::TTMetal); #endif - detail::currentRuntime = runtime; + detail::globalCurrentRuntime = runtime; } void setCompatibleRuntime(const Binary &binary) { diff --git a/runtime/lib/ttmetal/runtime.cpp b/runtime/lib/ttmetal/runtime.cpp index 39fce01de..dd278bc7b 100644 --- a/runtime/lib/ttmetal/runtime.cpp +++ b/runtime/lib/ttmetal/runtime.cpp @@ -12,7 +12,7 @@ #include "ttmlir/Version.h" namespace tt::runtime::ttmetal { - +using ::tt::runtime::DeviceRuntime; constexpr inline std::size_t kHostBufferCommandQueueId = 0; using Events = std::vector>; using DeviceMesh = std::vector<::tt::tt_metal::Device *>; @@ -39,7 +39,8 @@ Tensor createTensor(std::shared_ptr data, desc.itemsize = itemsize; desc.dataType = dataType; std::shared_ptr tensor = std::make_shared(desc); - return Tensor(static_pointer_cast(tensor), data); + return Tensor(static_pointer_cast(tensor), data, + DeviceRuntime::TTMetal); } Device openDevice(std::vector const &deviceIds, @@ -52,11 +53,11 @@ Device openDevice(std::vector const &deviceIds, deviceMesh->push_back(CreateDevice(deviceId, num_hw_cqs)); ++i; } - return static_pointer_cast(deviceMesh); + return Device(static_pointer_cast(deviceMesh), DeviceRuntime::TTMetal); } void closeDevice(Device device) { - DeviceMesh &deviceMesh = device.as(); + DeviceMesh &deviceMesh = device.as(DeviceRuntime::TTMetal); for (::tt::tt_metal::Device *device : deviceMesh) { ::tt::tt_metal::CloseDevice(device); } @@ -111,8 +112,8 @@ Events maybeCopyHostOutputs(::tt::tt_metal::Device *device, Events copyEvents; int i = 0; for (Tensor const &outputHandle : outputHandles) { - if (TensorDesc const *hostTensor = - std::get_if(&outputHandle.as()); + if (TensorDesc const *hostTensor = std::get_if( + &outputHandle.as(DeviceRuntime::TTMetal)); hostTensor) { ::tt::tt_metal::CommandQueue &cq = device->command_queue(kHostBufferCommandQueueId); @@ -140,7 +141,7 @@ Event submit(Device deviceHandle, Binary executableHandle, ::tt::target::metal::TTMetalBinary const &fbb = *getBinary(executableHandle); ::tt::target::metal::Program const *program = fbb.programs()->Get(programIndex); - DeviceMesh &deviceMesh = deviceHandle.as(); + DeviceMesh &deviceMesh = deviceHandle.as(DeviceRuntime::TTMetal); assert(deviceMesh.size() == 1 && "Only one device is supported for now"); std::shared_ptr events = std::make_shared(); assert(program->device_programs()->size() == deviceMesh.size() && @@ -158,9 +159,9 @@ Event submit(Device deviceHandle, Binary executableHandle, for (unsigned i = 0; i < inputHandles.size(); ++i) { ::tt::target::TensorRef const *tensorRef = deviceProgram->inputs()->Get(i); - auto [buffer, event] = - prepareInput(device, inputHandles[i].as(), - inputHandles[i].data.get(), tensorRef); + auto [buffer, event] = prepareInput( + device, inputHandles[i].as(DeviceRuntime::TTMetal), + inputHandles[i].data.get(), tensorRef); inputs.emplace_back(deviceProgram->inputs()->Get(i)->global_id(), buffer, event); } @@ -172,8 +173,9 @@ Event submit(Device deviceHandle, Binary executableHandle, for (unsigned i = 0; i < outputHandles.size(); ++i) { ::tt::target::TensorRef const *tensorRef = deviceProgram->outputs()->Get(i); - std::shared_ptr<::tt::tt_metal::Buffer> buffer = - prepareOutput(device, &outputHandles[i].as(), tensorRef); + std::shared_ptr<::tt::tt_metal::Buffer> buffer = prepareOutput( + device, &outputHandles[i].as(DeviceRuntime::TTMetal), + tensorRef); outputs.emplace_back(deviceProgram->outputs()->Get(i)->global_id(), buffer); } @@ -195,11 +197,11 @@ Event submit(Device deviceHandle, Binary executableHandle, events->insert(events->end(), deviceEvents.begin(), deviceEvents.end()); } - return static_pointer_cast(events); + return Event(static_pointer_cast(events), DeviceRuntime::TTMetal); } void wait(Event event) { - Events events = event.as(); + Events events = event.as(DeviceRuntime::TTMetal); for (auto e : events) { ::tt::tt_metal::EventSynchronize(e); } diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp index 9f6f66202..4d647b806 100644 --- a/runtime/lib/ttnn/runtime.cpp +++ b/runtime/lib/ttnn/runtime.cpp @@ -10,6 +10,8 @@ namespace tt::runtime::ttnn { +using ::tt::runtime::DeviceRuntime; + template static BorrowedStorage createStorage(void *ptr, std::uint32_t numElements) { return BorrowedStorage( @@ -45,7 +47,7 @@ Tensor createTensor(std::shared_ptr data, auto tensor = std::make_shared<::ttnn::Tensor>( createStorage(data.get(), numElements, dataType), shape, utils::toTTNNDataType(dataType), ::ttnn::Layout::ROW_MAJOR); - return Tensor(tensor, data); + return Tensor(tensor, data, DeviceRuntime::TTNN); } Device openDevice(std::vector const &deviceIds, @@ -53,11 +55,11 @@ Device openDevice(std::vector const &deviceIds, assert(deviceIds.size() == 1 && "Only one device is supported for now"); assert(numHWCQs.empty() && "HWCQs are not supported for now"); auto &device = ::ttnn::open_device(deviceIds.front()); - return Device::borrow(device); + return Device::borrow(device, DeviceRuntime::TTNN); } void closeDevice(Device device) { - auto &ttnn_device = device.as<::ttnn::Device>(); + auto &ttnn_device = device.as<::ttnn::Device>(DeviceRuntime::TTNN); ::ttnn::close_device(ttnn_device); } @@ -74,25 +76,28 @@ Event submit(Device deviceHandle, Binary executableHandle, std::uint32_t programIndex, std::vector const &inputHandles, std::vector const &outputHandles) { - ::ttnn::Device &device = deviceHandle.as<::ttnn::Device>(); + ::ttnn::Device &device = deviceHandle.as<::ttnn::Device>(DeviceRuntime::TTNN); ::tt::target::ttnn::TTNNBinary const &fbb = *getBinary(executableHandle); std::vector<::ttnn::Tensor *> inputs; inputs.reserve(inputHandles.size()); for (auto &input : inputHandles) { + assert(input.matchesRuntime(DeviceRuntime::TTNN)); inputs.push_back(static_cast<::ttnn::Tensor *>(input.handle.get())); } std::vector<::ttnn::Tensor *> outputs; outputs.reserve(outputHandles.size()); for (auto &output : outputHandles) { + assert(output.matchesRuntime(DeviceRuntime::TTNN)); outputs.push_back(static_cast<::ttnn::Tensor *>(output.handle.get())); } tt::runtime::ttnn::runProgram(device, fbb.programs()->Get(programIndex), inputs, outputs); - return Event(nullptr); + return Event(nullptr, DeviceRuntime::TTNN); } -void wait(Event) { +void wait(Event event) { // Not implemented + assert(event.matchesRuntime(DeviceRuntime::TTNN)); } } // namespace tt::runtime::ttnn