Skip to content

Commit

Permalink
PyBind GoldenTensor w/ Integral Types
Browse files Browse the repository at this point in the history
* PyBind `GoldenTensor` w/ Integral Types

This change introduces a PyBound class `GoldenTensor` that reflects the
same type from TTRT. This binding supports the following element types
for a `GoldenTensor`:
  - uint8
  - uint16
  - uint32
  - float32

Also, this change reverts PR #1488, since it unintentionally broke some things; those problems will be addressed in a separate PR.

Closes #1334
  • Loading branch information
ctodTT authored Dec 11, 2024
1 parent 9b398a3 commit 8bc14e2
Show file tree
Hide file tree
Showing 9 changed files with 146 additions and 70 deletions.
9 changes: 4 additions & 5 deletions python/test_infra/ttir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,30 +290,29 @@ def eltwise_proxy(
while len(stack) > 0 and stack[0].filename == cur_filename:
stack = stack[1:]

id = self.get_next_global_id()
loc = Location.file(stack[0].filename, stack[0].lineno, id)

assert (
len(stack) > 0
), "Top of callstack to builder funcs must be outside this file"

with self._ctx, self._loc:
output = self.empty(self.get_shape(inputs[0]))

id = self.get_next_global_id()

op = op_ttir_function(
[self._get_type(output)],
inputs,
[output],
self._get_operand_constraint_attr(3),
loc=loc,
loc=Location.name(str(id)),
)

goldens = []
for input in inputs:
goldens.append(self._get_golden_tensor(input))

golden = Golden(op_golden_function(*goldens))
self.id_golden_map[str(loc)] = golden
self.id_golden_map[str(id)] = golden
self._store_golden(op, golden)
self._override_golden(output, golden)

Expand Down
1 change: 1 addition & 0 deletions runtime/include/tt/runtime/detail/debug.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ struct Hooks {

mutable std::optional<std::function<void(Binary, CallbackContext, OpContext)>>
operatorCallback;

#else
constexpr Hooks() = default;
#endif
Expand Down
1 change: 1 addition & 0 deletions runtime/tools/python/ttrt/binary/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
load_binary_from_capsule,
load_system_desc_from_path,
Flatbuffer,
GoldenTensor,
)
from . import stats

Expand Down
90 changes: 74 additions & 16 deletions runtime/tools/python/ttrt/binary/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,8 @@ PYBIND11_MODULE(_C, m) {
&tt::runtime::Binary::getFileIdentifier)
.def("as_json", &tt::runtime::Binary::asJson)
.def("store", &tt::runtime::Binary::store)
.def("get_debug_info_golden", [](tt::runtime::Binary &binary,
std::string &loc) {
const ::tt::target::GoldenTensor *goldenTensor =
binary.getDebugInfoGolden(loc);
if (goldenTensor == nullptr) {
return std::vector<float>();
}

int totalDataSize = std::accumulate((*goldenTensor->shape()).begin(),
(*goldenTensor->shape()).end(), 1,
std::multiplies<int64_t>());
std::vector<float> dataVec(totalDataSize);
std::memcpy(dataVec.data(), goldenTensor->data()->data(),
totalDataSize * sizeof(float));
return dataVec;
});
.def("get_debug_info_golden", &::tt::runtime::Binary::getDebugInfoGolden,
py::return_value_policy::reference);
py::class_<tt::runtime::SystemDesc>(m, "SystemDesc")
.def_property_readonly("version", &tt::runtime::SystemDesc::getVersion)
.def_property_readonly("ttmlir_git_hash",
Expand All @@ -66,4 +52,76 @@ PYBIND11_MODULE(_C, m) {
.handle); // Dereference capsule, and then dereference shared_ptr*
});
m.def("load_system_desc_from_path", &tt::runtime::SystemDesc::loadFromPath);

/**
 * Python binding for the flatbuffer-backed `::tt::target::GoldenTensor` type.
 *
 * Exposes read-only `name`, `shape`, `stride`, and `dtype` accessors plus the
 * Python buffer protocol, so a GoldenTensor's underlying data can be wrapped
 * zero-copy (e.g. via `torch.frombuffer` / `numpy.frombuffer`).
 */
py::class_<tt::target::GoldenTensor>(m, "GoldenTensor", py::buffer_protocol())
    // Tensor name as a Python str (copied out of the flatbuffer).
    .def_property_readonly(
        "name",
        [](::tt::target::GoldenTensor const *t) -> std::string {
          assert(t != nullptr && t->name() != nullptr);
          return t->name()->str();
        })
    // Shape as a list of ints (copied; flatbuffer stores int64, narrowed to
    // int here).
    .def_property_readonly(
        "shape",
        [](::tt::target::GoldenTensor const *t) -> std::vector<int> {
          assert(t != nullptr && t->shape() != nullptr);
          return std::vector<int>(t->shape()->begin(), t->shape()->end());
        })
    // Stride as a list of ints (copied from the flatbuffer).
    .def_property_readonly(
        "stride",
        [](::tt::target::GoldenTensor const *t) -> std::vector<int> {
          assert(t != nullptr && t->stride() != nullptr);
          return std::vector<int>(t->stride()->begin(), t->stride()->end());
        })
    // Element type, exposed as the PyBound `::tt::target::DataType` enum.
    .def_property_readonly("dtype", &::tt::target::GoldenTensor::dtype)
    // Buffer-protocol hook: hands Python a view directly over the
    // flatbuffer's data bytes (no copy). The view is only valid while the
    // owning binary/flatbuffer is alive.
    .def_buffer([](tt::target::GoldenTensor const *t) -> py::buffer_info {
      assert(t != nullptr && t->data() != nullptr && t->shape() != nullptr &&
             t->stride() != nullptr);

      // Format string to be passed to `py::buffer_info`
      std::string format;

      // Element size to be passed to `py::buffer_info`
      size_t size;

      // Map the flatbuffer DataType onto the equivalent native element
      // type. Only uint8/16/32 and float32 are supported by this binding.
      switch (t->dtype()) {

      case tt::target::DataType::UInt8:
        format = py::format_descriptor<uint8_t>::format();
        size = sizeof(uint8_t);
        break;

      case tt::target::DataType::UInt16:
        format = py::format_descriptor<uint16_t>::format();
        size = sizeof(uint16_t);
        break;

      case tt::target::DataType::UInt32:
        format = py::format_descriptor<uint32_t>::format();
        size = sizeof(uint32_t);
        break;

      case tt::target::DataType::Float32:
        format = py::format_descriptor<float>::format();
        size = sizeof(float);
        break;

      default:
        throw std::runtime_error(
            "Only 32-bit floats and unsigned ints are currently supported "
            "for GoldenTensor bindings");
      }

      return py::buffer_info(
          (void *)t->data()->data(), /* ptr to underlying data */
          size,                      /* size of one element, in bytes */
          format,                    /* pybind format descriptor string */
          t->shape()->size(),        /* rank (number of dimensions) */
          *(t->shape()),             /* shape, per dimension */
          // NOTE(review): py::buffer_info expects strides in BYTES; the
          // flatbuffer stride() values are passed through unscaled here.
          // Confirm the flatbuffer stores byte strides — if they are
          // element strides, each entry must be multiplied by `size`.
          *(t->stride()), /* stride of buffer */
          // readonly=false: the exported view is writable, despite the
          // accessor being const.
          false);
    });
}
38 changes: 12 additions & 26 deletions runtime/tools/python/ttrt/common/golden.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,6 @@
#
# SPDX-License-Identifier: Apache-2.0

import os
import json
import importlib.machinery
import sys
import signal
import os
import io
import subprocess
import time
import socket
from pkg_resources import get_distribution
import shutil
import atexit
import re
from functools import partial

Expand Down Expand Up @@ -148,26 +135,19 @@ def golden_partial_function(
op_context, program_context
)

if len(op_golden_tensor) == 0:
print("Golden tensor is empty - skipping golden comparison")
if op_golden_tensor is None:
print("Golden tensor is None - skipping golden comparison")
return

if len(op_output_tensor) == 0:
print("Output tensor is empty - skipping golden comparison")
return

if len(op_golden_tensor) != len(op_output_tensor):
print(
"Golden and output tensor sizes do not match - skipping golden comparison"
)
return
dtype = ttrt_datatype_to_torch_dtype(op_golden_tensor.dtype)

golden_tensor_torch = torch.frombuffer(op_golden_tensor, dtype=dtype).flatten()

golden_tensor_torch = torch.tensor(
op_golden_tensor, dtype=torch.float32
).flatten()
output_tensor_torch = torch.tensor(
op_output_tensor, dtype=torch.float32
).flatten()
output_tensor_torch = torch.tensor(op_output_tensor, dtype=dtype).flatten()

if golden_runtime_config.save_golden_tensors:
torch.save(
Expand All @@ -179,6 +159,12 @@ def golden_partial_function(
f"{golden_runtime_config.artifact_dir}/{loc}_device.pt",
)

if golden_tensor_torch.shape != output_tensor_torch.shape:
print(
"Golden and output tensor shapes do not match - skipping golden comparison"
)
return

_, _, cal_pcc, output_str = get_atol_rtol_pcc(
golden_tensor_torch, output_tensor_torch
)
Expand Down
24 changes: 9 additions & 15 deletions runtime/tools/python/ttrt/common/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,6 @@
# SPDX-License-Identifier: Apache-2.0

import os
import json
import importlib.machinery
import sys
import signal
import os
import io
import subprocess
import time
import socket
from pkg_resources import get_distribution
import shutil
import atexit

from ttrt.common.util import *
from ttrt.common.query import Query
Expand Down Expand Up @@ -406,10 +394,16 @@ def _execute(binaries):
f"input_{i}"
)

if len(golden_tensor) != 0:
golden_inputs.append(
torch.tensor(golden_tensor, dtype=torch.float32)
if golden_tensor is not None:

dtype = ttrt_datatype_to_torch_dtype(
golden_tensor.dtype
)

golden_tensor_torch = torch.frombuffer(
golden_tensor, dtype=dtype
)
golden_inputs.append(golden_tensor_torch)

program.populate_inputs(
Run.TorchInitializer.get_initilizer(self["--init"]),
Expand Down
50 changes: 42 additions & 8 deletions runtime/tools/python/ttrt/common/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,11 @@
import json
import importlib.machinery
import importlib.util
import sys
import signal
import os
import io
import subprocess
import time
import socket
from pkg_resources import get_distribution
import shutil

import ttrt.binary
import torch


# environment tweaks
if "LOGGER_LEVEL" not in os.environ:
Expand All @@ -25,6 +19,46 @@
os.environ["TT_METAL_LOGGER_LEVEL"] = "FATAL"


def ttrt_datatype_to_torch_dtype(dtype) -> torch.dtype:
from ttrt.runtime._C import DataType

"""Converts a PyBound `::tt::target::DataType` into a `torch.dtype`.
Currently, only `float32`, `uint32`, `uint16`, & `uint8` are supported for
this conversion
Arguments
---------
dtype : DataType
A datatype from the PyBound `DataType` enum from ttrt
Returns
-------
A `torch.dtype` corresponding to `dtype`
Throws
------
A `ValueError` if `dtype` is not one of `Float32`, `UInt32`, `UInt16`, or `UInt8`
"""
match dtype:
case DataType.Float32:
return torch.float32
case DataType.UInt32:
return torch.uint32
case DataType.UInt16:
return torch.uint16
case DataType.UInt8:
return torch.uint8
case _:
raise ValueError(
"Only F32 and unsigned integers are supported in the runtime"
)


def get_ttrt_metal_home_path():
package_name = "ttrt"
spec = importlib.util.find_spec(package_name)
Expand Down
1 change: 1 addition & 0 deletions runtime/tools/python/ttrt/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
deallocate_tensor,
WorkaroundEnv,
get_op_loc_info,
unregister_hooks,
)
except ModuleNotFoundError:
raise ImportError(
Expand Down
2 changes: 2 additions & 0 deletions runtime/tools/python/ttrt/runtime/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,4 +230,6 @@ PYBIND11_MODULE(_C, m) {
::tt::runtime::debug::Hooks::get().unregisterHooks();
};
m.add_object("_cleanup", py::capsule(cleanup_callback));
m.def("unregister_hooks",
[]() { ::tt::runtime::debug::Hooks::get().unregisterHooks(); });
}

0 comments on commit 8bc14e2

Please sign in to comment.