Skip to content

Commit

Permalink
Bind name, shape & stride for GoldenTensor
Browse files Browse the repository at this point in the history
  • Loading branch information
ctodTT committed Nov 22, 2024
1 parent 73b3680 commit d7580d3
Showing 1 changed file with 33 additions and 1 deletion.
34 changes: 33 additions & 1 deletion runtime/tools/python/ttrt/binary/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,39 @@ PYBIND11_MODULE(_C, m) {
* Binding for the `GoldenTensor` type
*/
py::class_<tt::target::GoldenTensor>(m, "GoldenTensor", py::buffer_protocol())
// TODO: bind shape, name, and stride properties
.def_property_readonly(
"name",
[](::tt::target::GoldenTensor const *t) -> std::string {
if (t == nullptr) {
throw std::runtime_error("GoldenTensor cannot be null");
}
if (t->name() == nullptr) {
throw std::runtime_error("GoldenTensor `name` is null");
}
return t->name()->str();
})
.def_property_readonly(
"shape",
[](::tt::target::GoldenTensor const *t) -> std::vector<int> {
if (t == nullptr) {
throw std::runtime_error("GoldenTensor cannot be null");
}
if (t->shape() == nullptr) {
throw std::runtime_error("GoldenTensor `shape` pointer is null");
}
return std::vector<int>(t->shape()->begin(), t->shape()->end());
})
.def_property_readonly(
"stride",
[](::tt::target::GoldenTensor const *t) -> std::vector<int> {
if (t == nullptr) {
throw std::runtime_error("GoldenTensor cannot be null");
}
if (t->stride() == nullptr) {
throw std::runtime_error("GoldenTensor `stride` pointer is null");
}
return std::vector<int>(t->stride()->begin(), t->stride()->end());
})
.def_property_readonly("dtype", &::tt::target::GoldenTensor::dtype)
.def_buffer([](tt::target::GoldenTensor const *t) -> py::buffer_info {
// NULL checks
Expand Down

0 comments on commit d7580d3

Please sign in to comment.