Skip to content

Commit

Permalink
#1507: Added ability to gather golden result information and consolid…
Browse files Browse the repository at this point in the history
…ate into report as well as store golden and device artifacts generated during runtime (#1544)
  • Loading branch information
tapspatel authored Dec 11, 2024
1 parent 31e5518 commit 41b486c
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 15 deletions.
8 changes: 7 additions & 1 deletion runtime/include/tt/runtime/detail/debug.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,19 @@ struct Hooks {
#endif
}

void unregisterHooks() const {
#if defined(TT_RUNTIME_DEBUG) && TT_RUNTIME_DEBUG == 1
operatorCallback = std::nullopt;
#endif
}

private:
#if defined(TT_RUNTIME_DEBUG) && TT_RUNTIME_DEBUG == 1
Hooks(std::optional<std::function<void(Binary, CallbackContext, OpContext)>>
operatorCallback)
: operatorCallback(operatorCallback) {}

std::optional<std::function<void(Binary, CallbackContext, OpContext)>>
mutable std::optional<std::function<void(Binary, CallbackContext, OpContext)>>
operatorCallback;
#else
constexpr Hooks() = default;
Expand Down
77 changes: 74 additions & 3 deletions runtime/tools/python/ttrt/common/golden.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,27 @@
import shutil
import atexit
import re
from functools import partial

from ttrt.common.util import *


class GoldenRuntimeConfig:
def __init__(
self,
atol=1e-08,
rtol=1e-05,
pcc=0.99,
artifact_dir="",
save_golden_tensors=False,
):
self.artifact_dir = artifact_dir
self.pcc = pcc
self.atol = atol
self.rtol = rtol
self.save_golden_tensors = save_golden_tensors


def get_atol_rtol_pcc(golden, calculated):
import numpy as np
import torch
Expand Down Expand Up @@ -103,20 +120,33 @@ def get_pcc(golden, calculated):
)


def golden(binary, programContext, opContext):
def golden_partial_function(
golden_runtime_config, golden_results_data, binary, program_context, op_context
):
import torch
import ttrt.runtime
import ttrt.binary

print("-----------executing golden comparision-----------")

try:
loc = ttrt.runtime.get_op_loc_info(opContext)
op_debug_str = ttrt.runtime.get_op_debug_str(op_context)

# find matching golden tensor based on loc in op debug string
match = re.search(r"loc\(([^)]+)\)", op_debug_str)

if not match:
print(f"debug_str={op_debug_str}")
print("No location found in debug string - skipping golden comparison")
return

loc = match.group(1).replace('"', "")
print(f"found location={loc}")

op_golden_tensor = binary.get_debug_info_golden(loc)
op_output_tensor = ttrt.runtime.get_op_output_tensor(opContext, programContext)
op_output_tensor = ttrt.runtime.get_op_output_tensor(
op_context, program_context
)

if len(op_golden_tensor) == 0:
print("Golden tensor is empty - skipping golden comparison")
Expand All @@ -139,11 +169,52 @@ def golden(binary, programContext, opContext):
op_output_tensor, dtype=torch.float32
).flatten()

if golden_runtime_config.save_golden_tensors:
torch.save(
golden_tensor_torch,
f"{golden_runtime_config.artifact_dir}/{loc}_golden.pt",
)
torch.save(
output_tensor_torch,
f"{golden_runtime_config.artifact_dir}/{loc}_device.pt",
)

_, _, cal_pcc, output_str = get_atol_rtol_pcc(
golden_tensor_torch, output_tensor_torch
)

print(f"PCC={cal_pcc}")
print(output_str)

results = {}
results["expected_pcc"] = golden_runtime_config.pcc
results["actual_pcc"] = cal_pcc
results["atol"] = golden_runtime_config.atol
results["rtol"] = golden_runtime_config.rtol
results["allclose"] = torch.allclose(
golden_tensor_torch,
output_tensor_torch,
atol=golden_runtime_config.atol,
rtol=golden_runtime_config.rtol,
)
results["max"] = torch.max(
torch.abs(golden_tensor_torch - output_tensor_torch)
).item()
results["mean_absolute_error"] = torch.mean(
torch.abs(golden_tensor_torch - output_tensor_torch)
).item()
results["root_mean_square_error"] = torch.sqrt(
torch.mean((golden_tensor_torch - output_tensor_torch) ** 2)
).item()
results["cosine_similarity"] = torch.nn.functional.cosine_similarity(
golden_tensor_torch.unsqueeze(0), output_tensor_torch.unsqueeze(0)
).item()

golden_results_data[loc] = results

finally:
print("-----------finished executing golden comparision-----------")


def get_golden_fn(golden_runtime_config, golden_results_data):
return partial(golden_partial_function, golden_runtime_config, golden_results_data)
58 changes: 54 additions & 4 deletions runtime/tools/python/ttrt/common/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from ttrt.common.util import *
from ttrt.common.query import Query
from ttrt.common.golden import golden
from ttrt.common.golden import get_golden_fn, GoldenRuntimeConfig


class Run:
Expand Down Expand Up @@ -103,6 +103,13 @@ def initialize_api():
choices=None,
help="atol for golden test",
)
Run.register_arg(
name="--pcc",
type=float,
default=0.99,
choices=None,
help="pcc for golden test",
)
Run.register_arg(
name="--seed",
type=int,
Expand Down Expand Up @@ -159,6 +166,13 @@ def initialize_api():
choices=[True, False],
help="run golden comparison for intermediate and output tensors",
)
Run.register_arg(
name="--save-golden-tensors",
type=bool,
default=False,
choices=[True, False],
help="save golden and device tensors that are compared during callback runtime",
)
Run.register_arg(
name="binary",
type=str,
Expand Down Expand Up @@ -348,9 +362,6 @@ def _execute(binaries):
self.logging.warning(f"no binaries found to run - returning early")
return

if self["--golden"]:
callback_env = ttrt.runtime.DebugHooks.get(golden)

debug_env = ttrt.runtime.DebugEnv.get(
self["--load-kernels-from-disk"], self["--enable-async-ttnn"]
)
Expand All @@ -373,6 +384,9 @@ def _execute(binaries):
try:
self.logging.info(f"evaluating binary={bin.file_path}")

if self["--save-artifacts"]:
self.artifacts.create_binary_artifacts_folder(bin)

program_indices = []
if self["--program-index"] == "all":
program_indices.extend(range(bin.get_num_programs()))
Expand Down Expand Up @@ -440,6 +454,20 @@ def _execute(binaries):
total_outputs.append(outputs)

event = None
golden_results_data = {}
if self["--golden"]:
callback_env = ttrt.runtime.DebugHooks.get(
get_golden_fn(
GoldenRuntimeConfig(
self["--atol"],
self["--rtol"],
self["--pcc"],
f"{self.artifacts.get_binary_folder_path(bin)}/run/program_{program_index}",
self["--save-golden-tensors"],
),
golden_results_data,
)
)
for loop in range(self["--loops"]):
self.logging.debug(
f"starting loop={loop+1}/{self['--loops']} for binary={bin.file_path}"
Expand Down Expand Up @@ -519,6 +547,28 @@ def _execute(binaries):
self.logging.debug(f"{tensor}\n")

device.deallocate_buffers()

# if golden comparison is enabled, check golden results json file to see if test passed
if self["--golden"]:
if self["--save-artifacts"]:
golden_results_file_path = f"{self.artifacts.get_binary_folder_path(bin)}/run/program_{program_index}/golden_results.json"

with open(
golden_results_file_path, "w"
) as json_file:
json.dump(
golden_results_data, json_file, indent=4
)

for loc, golden_data in golden_results_data.items():
if (
golden_data["actual_pcc"]
< golden_data["expected_pcc"]
):
raise Exception(
f"Failed: golden comparison failed for program={program_index}, actual_pcc={golden_data['actual_pcc']} < expected_pcc={golden_data['expected_pcc']}"
)

except Exception as e:
test_result = {
"file_path": bin.file_path,
Expand Down
19 changes: 12 additions & 7 deletions runtime/tools/python/ttrt/common/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,16 +409,22 @@ def clean_artifacts(self):
def clean_binary_artifacts(self, binary):
self.file_manager.remove_directory(self.get_binary_folder_path(binary))

def create_binary_artifacts_folder(self, binary):
binary_folder = self.get_binary_folder_path(binary)
self.file_manager.create_directory(binary_folder)
self.file_manager.create_directory(f"{binary_folder}/run")
self.file_manager.create_directory(f"{binary_folder}/perf")

for program in binary.programs:
program_folder = f"{binary_folder}/run/program_{program.index}"
self.file_manager.create_directory(program_folder)

def save_binary(self, binary, query=None):
binary_folder = self.get_binary_folder_path(binary)

self.logging.info(
f"saving binary={binary.file_path} to binary_folder={binary_folder}"
)
self.file_manager.create_directory(binary_folder)
self.file_manager.create_directory(f"{binary_folder}/run")
self.file_manager.create_directory(f"{binary_folder}/perf")

self.file_manager.copy_file(f"{binary_folder}", binary.file_path)

for program in binary.programs:
Expand All @@ -427,20 +433,19 @@ def save_binary(self, binary, query=None):
self.logging.info(
f"saving program={program.index} for binary={binary.file_path} to program_folder={program_folder}"
)
self.file_manager.create_directory(program_folder)

for i in range(len(program.input_tensors)):
self.save_torch_tensor(
program_folder,
program.input_tensors[i],
f"program_{program.index}_input_{i}.pt",
f"input_{i}.pt",
)

for i in range(len(program.output_tensors)):
self.save_torch_tensor(
program_folder,
program.output_tensors[i],
f"program_{program.index}_output_{i}.pt",
f"output_{i}.pt",
)

if query != None:
Expand Down
8 changes: 8 additions & 0 deletions runtime/tools/python/ttrt/runtime/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,4 +222,12 @@ PYBIND11_MODULE(_C, m) {
&tt::runtime::ttnn::test::getHostRowMajorLayout, py::arg("dtype"),
"Get host row major layout");
#endif

/**
* Cleanup code to force a well ordered destruction w.r.t. the GIL
*/
auto cleanup_callback = []() {
::tt::runtime::debug::Hooks::get().unregisterHooks();
};
m.add_object("_cleanup", py::capsule(cleanup_callback));
}

0 comments on commit 41b486c

Please sign in to comment.