From 7ca95d8658db5383325b7ca51cef21aaeaab89ca Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sun, 2 Jun 2024 14:39:44 -0700 Subject: [PATCH] Expose BFloat in Python bindings (#8255) There are two parts to support for BFloat16 in Python: 1) Ability to define kernels and AOT compile them [fixed in this PR] 2) Ability to call kernels from Python This fixes part 1, which is what I need for my use case. Part 2 is blocked on bfloat16 support in Python buffer protocols. See #6849 for more details. --- python_bindings/src/halide/halide_/PyType.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python_bindings/src/halide/halide_/PyType.cpp b/python_bindings/src/halide/halide_/PyType.cpp index f2592e304ef0..09ac62f34570 100644 --- a/python_bindings/src/halide/halide_/PyType.cpp +++ b/python_bindings/src/halide/halide_/PyType.cpp @@ -32,6 +32,9 @@ std::string halide_type_to_string(const Type &type) { case halide_type_float: stream << "float"; break; + case halide_type_bfloat: + stream << "bfloat"; + break; case halide_type_handle: if (type.handle_type) { stream << type.handle_type->inner_name.name; @@ -67,6 +70,7 @@ void define_type(py::module &m) { .def("is_vector", &Type::is_vector) .def("is_scalar", &Type::is_scalar) .def("is_float", &Type::is_float) + .def("is_bfloat", &Type::is_bfloat) .def("is_int", &Type::is_int) .def("is_uint", &Type::is_uint) .def("is_handle", &Type::is_handle) @@ -94,6 +98,7 @@ void define_type(py::module &m) { m.def("Int", Int, py::arg("bits"), py::arg("lanes") = 1); m.def("UInt", UInt, py::arg("bits"), py::arg("lanes") = 1); m.def("Float", Float, py::arg("bits"), py::arg("lanes") = 1); + m.def("BFloat", BFloat, py::arg("bits"), py::arg("lanes") = 1); m.def("Bool", Bool, py::arg("lanes") = 1); m.def("Handle", make_handle, py::arg("lanes") = 1);