diff --git a/python_bindings/src/halide/halide_/PyType.cpp b/python_bindings/src/halide/halide_/PyType.cpp index f2592e304ef0..09ac62f34570 100644 --- a/python_bindings/src/halide/halide_/PyType.cpp +++ b/python_bindings/src/halide/halide_/PyType.cpp @@ -32,6 +32,9 @@ std::string halide_type_to_string(const Type &type) { case halide_type_float: stream << "float"; break; + case halide_type_bfloat: + stream << "bfloat"; + break; case halide_type_handle: if (type.handle_type) { stream << type.handle_type->inner_name.name; @@ -67,6 +70,7 @@ void define_type(py::module &m) { .def("is_vector", &Type::is_vector) .def("is_scalar", &Type::is_scalar) .def("is_float", &Type::is_float) + .def("is_bfloat", &Type::is_bfloat) .def("is_int", &Type::is_int) .def("is_uint", &Type::is_uint) .def("is_handle", &Type::is_handle) @@ -94,6 +98,7 @@ void define_type(py::module &m) { m.def("Int", Int, py::arg("bits"), py::arg("lanes") = 1); m.def("UInt", UInt, py::arg("bits"), py::arg("lanes") = 1); m.def("Float", Float, py::arg("bits"), py::arg("lanes") = 1); + m.def("BFloat", BFloat, py::arg("bits"), py::arg("lanes") = 1); m.def("Bool", Bool, py::arg("lanes") = 1); m.def("Handle", make_handle, py::arg("lanes") = 1);