Skip to content

Commit

Permalink
Added LOC info to TTRT Perf output (#1401)
Browse files Browse the repository at this point in the history
* Added LOC info to TTRT Perf output

* Rebase fixes

* Non-Runtime build fix

* Readability fixes

* Fixed row removal in perf results

* Added LocInfo
  • Loading branch information
vprajapati-tt authored Dec 2, 2024
1 parent 12fc71c commit 139c7cc
Show file tree
Hide file tree
Showing 14 changed files with 212 additions and 70 deletions.
1 change: 1 addition & 0 deletions include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ union OpType {
// A single serialized operation inside a Program.
table Operation {
// Op-specific payload; the concrete variant is one of the OpType union members.
type: OpType;
// Free-form debug text for the op.
// NOTE(review): presumably the printed MLIR op produced at serialization time — confirm in TTNNToFlatbuffer.cpp.
debug_info: string;
// Source-location text for the op; read back by the TTNN runtime's getOpLocInfo.
// NOTE(review): as a non-required string it may be absent in binaries written
// before this field existed — readers should expect a null accessor result.
loc_info: string;
}

table Program {
Expand Down
10 changes: 9 additions & 1 deletion include/ttmlir/Target/Utils/FuncOpToProgram.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ inline std::string getOpDebugString(mlir::Operation *op,
return str;
};

/// Renders the location attached to `op` into a string by printing
/// `op->getLoc()` through an llvm::raw_string_ostream.
///
/// \param op Operation whose location is printed; must be non-null.
/// \return The printed location text.
inline std::string getOpLocInfo(mlir::Operation *op) {
  std::string locText;
  llvm::raw_string_ostream locStream(locText);
  const auto opLocation = op->getLoc();
  opLocation.print(locStream);
  return locText;
}

inline Value getOperandThroughDPSOps(Value value) {
auto *op = value.getDefiningOp();
if (!op) {
Expand Down Expand Up @@ -76,7 +83,8 @@ Program<OpT> funcOpToProgram(FlatbufferObjectCache &cache, func::FuncOp entry,
}
} else {
std::string debugStr = getOpDebugString(op, printFlags);
program.ops.push_back(fn(cache, op, debugStr));
std::string locInfo = getOpLocInfo(op);
program.ops.push_back(fn(cache, op, debugStr, locInfo));
}
});

Expand Down
6 changes: 5 additions & 1 deletion include/ttmlir/Target/Utils/MLIRToFlatbuffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,11 @@ toDebugInfo(::flatbuffers::FlatBufferBuilder &fbb, std::string const &name,
ModuleOp module) {
std::string source;
llvm::raw_string_ostream os(source);
module->print(os);

mlir::OpPrintingFlags flags;
flags.enableDebugInfo(); // Enable the loc dumping
module->print(os, flags);

return ::tt::target::CreateMLIRDirect(fbb, name.c_str(), source.c_str());
}
} // namespace mlir::tt
Expand Down
181 changes: 113 additions & 68 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions runtime/include/tt/runtime/detail/ttmetal.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ void wait(Event event);

std::string getOpDebugString(OpContext opContextHandle);

std::string getOpLocInfo(OpContext opContextHandle);

Tensor getOpOutputTensor(OpContext opContextHandle,
CallbackContext programContextHandle);

Expand Down
2 changes: 2 additions & 0 deletions runtime/include/tt/runtime/detail/ttnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ void wait(Event event);

std::string getOpDebugString(OpContext opContextHandle);

std::string getOpLocInfo(OpContext opContextHandle);

Tensor getOpOutputTensor(OpContext opContextHandle,
CallbackContext programContextHandle);

Expand Down
2 changes: 2 additions & 0 deletions runtime/include/tt/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ void wait(Event event);

std::string getOpDebugString(OpContext opContextHandle);

std::string getOpLocInfo(OpContext opContextHandle);

Tensor getOpOutputTensor(OpContext opContextHandle,
CallbackContext programContextHandle);

Expand Down
15 changes: 15 additions & 0 deletions runtime/lib/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,21 @@ std::string getOpDebugString(OpContext opContextHandle) {
throw std::runtime_error("runtime is not enabled");
}

/// Returns the location info string for the op behind `opContextHandle`,
/// forwarding to whichever device runtime is currently selected.
///
/// Only back ends compiled in (TT_RUNTIME_ENABLE_TTNN /
/// TT_RUNTIME_ENABLE_TTMETAL) are considered.
///
/// \throws std::runtime_error if no compiled-in runtime matches the current
///         runtime.
std::string getOpLocInfo(OpContext opContextHandle) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
  const bool usingTTNN = getCurrentRuntime() == DeviceRuntime::TTNN;
  if (usingTTNN) {
    return ::tt::runtime::ttnn::getOpLocInfo(opContextHandle);
  }
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
  const bool usingTTMetal = getCurrentRuntime() == DeviceRuntime::TTMetal;
  if (usingTTMetal) {
    return ::tt::runtime::ttmetal::getOpLocInfo(opContextHandle);
  }
#endif

  // Neither enabled back end claimed the current runtime.
  throw std::runtime_error("runtime is not enabled");
}

Tensor getOpOutputTensor(OpContext opContextHandle,
CallbackContext programContextHandle) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
Expand Down
6 changes: 6 additions & 0 deletions runtime/lib/ttmetal/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,12 @@ std::string getOpDebugString(OpContext opContextHandle) {
return "";
}

/// Placeholder for the metal runtime: location info is not available here.
///
/// Logs a warning on every call and returns an empty string; the handle is
/// not inspected.
std::string getOpLocInfo(OpContext opContextHandle) {
  LOG_WARNING("obtaining op location info for metal runtime not implemented");
  return std::string();
}

Tensor getOpOutputTensor(OpContext opContextHandle,
CallbackContext programContextHandle) {
// Not implemented
Expand Down
11 changes: 11 additions & 0 deletions runtime/lib/ttnn/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,19 @@
#include "tt/runtime/utils.h"
#include "ttmlir/Target/TTNN/program_generated.h"

#ifdef TT_RUNTIME_ENABLE_PERF_TRACE
#include "tracy/Tracy.hpp"
#endif

namespace tt::runtime::ttnn {
using LogType = ::tt::runtime::logger::LogType;

/// Emits the op's location info as a Tracy message when perf tracing is
/// compiled in (TT_RUNTIME_ENABLE_PERF_TRACE); otherwise does nothing.
///
/// `loc_info` is an optional flatbuffer string, so its accessor returns null
/// for binaries serialized without the field. In that case a placeholder is
/// emitted rather than dereferencing null. A message is still sent so that
/// the ttrt perf post-processing, which pairs each op with the location
/// message preceding it, does not attribute a stale location to this op.
///
/// \param op Operation about to be executed; must be non-null.
void tracyLogOpLocation(const ::tt::target::ttnn::Operation *op) {
#ifdef TT_RUNTIME_ENABLE_PERF_TRACE
  static constexpr char kUnknownLoc[] = "loc(unknown)";
  const auto *locInfo = op->loc_info();
  if (locInfo == nullptr) {
    // sizeof includes the terminating NUL; Tracy wants the text length.
    TracyMessage(kUnknownLoc, sizeof(kUnknownLoc) - 1);
    return;
  }
  TracyMessage(locInfo->c_str(), locInfo->size());
#else
  // Parameter is intentionally unused when perf tracing is disabled.
  (void)op;
#endif
}

static ::tt::target::ttnn::TTNNBinary const *getBinary(Flatbuffer binary) {
bool isTTNN = ::tt::target::ttnn::SizePrefixedTTNNBinaryBufferHasIdentifier(
binary.handle.get());
Expand Down Expand Up @@ -74,6 +84,7 @@ class ProgramExecutor {
for (const ::tt::target::ttnn::Operation *op : *program->operations()) {
LOG_DEBUG(LogType::LogRuntimeTTNN,
"Executing operation: ", op->debug_info()->c_str());
tracyLogOpLocation(op);
runOperation(op);
runCallback(executableHandle, op, &context);
}
Expand Down
6 changes: 6 additions & 0 deletions runtime/lib/ttnn/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,12 @@ std::string getOpDebugString(OpContext opContextHandle) {
return std::string(opContext.debug_info()->c_str());
}

/// Returns the location info string stored on the TTNN operation referenced
/// by `opContextHandle`.
///
/// `loc_info` is an optional flatbuffer string: its accessor returns null when
/// the binary was serialized without the field. An empty string is returned in
/// that case instead of dereferencing null.
///
/// \param opContextHandle Handle wrapping a ::tt::target::ttnn::Operation.
/// \return The op's location text, or "" if the field is absent.
std::string getOpLocInfo(OpContext opContextHandle) {
  auto const &opContext =
      opContextHandle.as<::tt::target::ttnn::Operation>(DeviceRuntime::TTNN);
  auto const *locInfo = opContext.loc_info();
  if (locInfo == nullptr) {
    return std::string();
  }
  return locInfo->str();
}

Tensor getOpOutputTensor(OpContext opContextHandle,
CallbackContext programContextHandle) {
auto const &programContext =
Expand Down
37 changes: 37 additions & 0 deletions runtime/tools/python/ttrt/common/perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,16 @@
import atexit
import traceback
from pathlib import Path
import csv

from ttrt.common.util import *
from ttrt.common.query import Query


def get_loc_data_hook(binary, programContext, opContext):
    """Callback-shaped helper taking (binary, programContext, opContext).

    Fetches the op's debug string via ``ttrt.runtime.get_op_debug_str``.
    The value is not used and the function implicitly returns ``None``.

    NOTE(review): this looks like an unfinished stub — despite the name it
    neither reads location info nor records anything; confirm intent.
    """
    debug_text = ttrt.runtime.get_op_debug_str(opContext)


class Perf:
registered_args = {}

Expand Down Expand Up @@ -456,6 +461,38 @@ def signal_handler(sig, frame):
)

process_ops(None, None, False)

# Add post-processing steps to insert location data into the ops_perf data file
with open(profiler_csv_file_path, "r") as perf_file:
perf_reader = csv.DictReader(perf_file)
headers = list(perf_reader.fieldnames) + ["LOC"]
perf_data = list(perf_reader)

with open(profiler_csv_file_path, "w+") as perf_file, open(
tracy_ops_data_file_path, "r"
) as message_file:
message_reader = csv.reader(message_file, delimiter=";")
ops_index = 0
prev = None
for message in message_reader:
message = message[0] # Don't need timestamp information
if message.startswith("`"):
# This is a TTNN Message
# The location data is now in the previous message
# The order of data is maintained in perf_data so as the messages are received, they update the id last encountered.
# Now that we have a new message, we can update the location data from the previous message
if prev:
# Get the location data from the previous message and add it as new data for the perf_data (as a new col)
if len(perf_data) > ops_index:
perf_data[ops_index]["LOC"] = prev
ops_index += 1
else:
prev = message
perf_writer = csv.DictWriter(perf_file, fieldnames=headers)
perf_writer.writeheader()
for row in perf_data:
perf_writer.writerow(row)

self.file_manager.copy_file(
perf_folder_path,
profiler_csv_file_path,
Expand Down
2 changes: 2 additions & 0 deletions runtime/tools/python/ttrt/runtime/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ PYBIND11_MODULE(_C, m) {
"Get the input tensor of the op");
m.def("get_op_debug_str", &tt::runtime::getOpDebugString,
"Get the debug string of the op");
m.def("get_op_loc_info", &tt::runtime::getOpLocInfo,
"Get the location info of the op");

py::class_<tt::runtime::debug::Env>(m, "DebugEnv")
.def_static("get", &tt::runtime::debug::Env::get)
Expand Down
1 change: 1 addition & 0 deletions third_party/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ set(TTMETAL_INCLUDE_DIRS
${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/tt_metal/hw/inc/${ARCH_EXTRA_DIR}
${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/tt_metal/third_party/umd/src/firmware/riscv/${ARCH_NAME}
${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/tt_eager
${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal-build/include
${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/.cpmcache/reflect/e75434c4c5f669e4a74e4d84e0a30d7249c1e66f
${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/.cpmcache/nanomsg/28cc32d5bdb6a858fe53b3ccf7e923957e53eada/include
${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/.cpmcache/fmt/73b5ec45edbd92babfd91c3777a9e1ab9cac8238/include
Expand Down

0 comments on commit 139c7cc

Please sign in to comment.