From 9965dd4cabdad957cafad8d7cab4109168914c29 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 23 May 2023 13:19:40 -0700 Subject: [PATCH] Switch jaxlib to use nanobind instead of pybind11. nanobind has a number of advantages (https://nanobind.readthedocs.io/en/latest/why.html), notably speed of compilation and dispatch, but the main reason to do this for these bindings is because nanobind can target the Python Stable ABI starting with Python 3.12. This means that we will not need to ship per-Python version CUDA plugins starting with Python 3.12. PiperOrigin-RevId: 534535677 --- WORKSPACE | 8 +- examples/jax_cpp/BUILD | 4 +- jaxlib/BUILD | 23 +- jaxlib/absl_status_casters.h | 218 ++++++++++++++++++ jaxlib/cpu/BUILD | 8 +- jaxlib/cpu/ducc_fft.cc | 22 +- jaxlib/cpu/lapack.cc | 25 +- jaxlib/cuda/BUILD | 35 +-- jaxlib/gpu/blas.cc | 27 ++- jaxlib/gpu/blas_kernels.cc | 2 - jaxlib/gpu/linalg.cc | 14 +- jaxlib/gpu/prng.cc | 15 +- jaxlib/gpu/rnn.cc | 20 +- jaxlib/gpu/solver.cc | 59 ++--- jaxlib/gpu/sparse.cc | 79 ++++--- jaxlib/gpu/triton.cc | 123 +++++----- ...11_helpers.h => kernel_nanobind_helpers.h} | 35 ++- jaxlib/rocm/BUILD.bazel | 23 +- jaxlib/utils.cc | 40 ++-- third_party/nanobind/BUILD.bazel | 22 ++ third_party/nanobind/workspace.bzl | 26 +++ third_party/robin_map/BUILD.bazel | 17 ++ third_party/robin_map/workspace.bzl | 26 +++ 23 files changed, 620 insertions(+), 251 deletions(-) create mode 100644 jaxlib/absl_status_casters.h rename jaxlib/{kernel_pybind11_helpers.h => kernel_nanobind_helpers.h} (53%) create mode 100644 third_party/nanobind/BUILD.bazel create mode 100644 third_party/nanobind/workspace.bzl create mode 100644 third_party/robin_map/BUILD.bazel create mode 100644 third_party/robin_map/workspace.bzl diff --git a/WORKSPACE b/WORKSPACE index 09bb0f17115e..8a1cc536bf30 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -20,6 +20,12 @@ xla_workspace1() load("@xla//:workspace0.bzl", "xla_workspace0") xla_workspace0() - load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo") flatbuffers() + +load("//third_party/robin_map:workspace.bzl", robin_map = "repo") +robin_map() + +load("//third_party/nanobind:workspace.bzl", nanobind = "repo") +nanobind() + diff --git a/examples/jax_cpp/BUILD b/examples/jax_cpp/BUILD index 301f0e9d40d2..a36b6c1949c1 100644 --- a/examples/jax_cpp/BUILD +++ b/examples/jax_cpp/BUILD @@ -30,7 +30,7 @@ cc_binary( "@xla//xla/pjrt:tfrt_cpu_pjrt_client", "@xla//xla/service:hlo_proto_cc", "@xla//xla/tools:hlo_module_loader", - "@tsl///platform:logging", - "@tsl///platform:platform_port", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:platform_port", ], ) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 93edfb3fea9d..44a75fe3e7ce 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -101,8 +101,22 @@ exports_files([ ]) cc_library( - name = "kernel_pybind11_helpers", - hdrs = ["kernel_pybind11_helpers.h"], + name = "absl_status_casters", + hdrs = ["absl_status_casters.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "kernel_nanobind_helpers", + hdrs = ["kernel_nanobind_helpers.h"], copts = [ "-fexceptions", "-fno-strict-aliasing", @@ -110,8 +124,9 @@ cc_library( features = ["-use_header_modules"], deps = [ ":kernel_helpers", + "@tsl//tsl/python/lib/core:numpy", "@com_google_absl//absl/base", - "@pybind11", + "@nanobind", ], ) @@ -163,7 +178,7 @@ pybind_extension( "@xla//third_party/python_runtime:headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:inlined_vector", - "@pybind11", + "@nanobind", ], ) diff --git a/jaxlib/absl_status_casters.h b/jaxlib/absl_status_casters.h new file mode 100644 index 000000000000..1ed3c0a0ad38 --- /dev/null +++ b/jaxlib/absl_status_casters.h @@ -0,0 +1,218 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_ABSL_STATUS_CASTERS_H_ +#define JAXLIB_ABSL_STATUS_CASTERS_H_ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" + +namespace jax { + +// C++ -> Python caster helpers. +// +// Failing statuses become Python exceptions; OK Status() becomes None. +// +// Given there can be only a single global pybind11 type_caster for the +// `absl::Status` type, and given XLA wants a custom exception being raised, +// we use a dedicated helper to implement this feature without relying on a +// global `type_caster`. +// +// For example: +// +// - Functions without arguments: +// m.def("my_func", []() { ThrowIfError(MyFunc()); } +// - Classes with a single argument: +// py_class.def("delete", [](Buffer& self) { +// ThrowIfError(self.Delete()); +// } +// +// For functions with more arguments, you can either inline the arguments, +// or use the `ThrowIfErrorWrapper` wrapper defined below: +// +// m.def("my_func", ThrowIfErrorWrapper(MyFunc)); +// +// Nonstatic member functions can be wrapped by passing a +// pointer-to-member-function: +// ThrowIfErrorWrapper(&MyClass::MyMethod) + +inline void ThrowIfError(absl::Status src) { + if (!src.ok()) { + throw std::runtime_error(src.ToString()); + } +} + +// If one does not want to have to define a lambda specifying the inputs +// arguments, on can use the `ThrowIfErrorWrapper` wrapper. +// +// There are three specializations: +// - For free functions, `Sig` is the function type and `F` is `Sig&`. +// - For callable types, `Sig` is the pointer to member function type +// and `F` is the type of the callable. +// - For a nonstatic member function of a class `C`, `Sig` is the function type +// and `F` is Sig C::*. +// +// In the first two cases, the wrapper returns a callable with signature `Sig`; +// in the third case, the wrapper returns callable with a modified signature +// that takes a C instance as the first argument. +template +struct ThrowIfErrorWrapper; + +// C++17 "deduction guide" that guides class template argument deduction (CTAD) +// For free functions. +template +ThrowIfErrorWrapper(F) -> ThrowIfErrorWrapper; + +// For callable types (with operator()). +template +ThrowIfErrorWrapper(absl::Status (&)(Args...)) + -> ThrowIfErrorWrapper; + +// For unbound nonstatic member functions. +template +ThrowIfErrorWrapper(absl::Status (C::*)(Args...)) + -> ThrowIfErrorWrapper; + +// Template specializations. + +// For free functions. +template +struct ThrowIfErrorWrapper { + explicit ThrowIfErrorWrapper(absl::Status (&f)(Args...)) : func(f) {} + void operator()(Args... args) { + ThrowIfError(func(std::forward(args)...)); + } + absl::Status (&func)(Args...); +}; + +// For callable types (with operator()), non-const and const versions. +template +struct ThrowIfErrorWrapper { + explicit ThrowIfErrorWrapper(F&& f) : func(std::move(f)) {} + void operator()(Args... args) { + ThrowIfError(func(std::forward(args)...)); + } + F func; +}; +template +struct ThrowIfErrorWrapper { + explicit ThrowIfErrorWrapper(F&& f) : func(std::move(f)) {} + void operator()(Args... args) const { + ThrowIfError(func(std::forward(args)...)); + } + F func; +}; + +// For unbound nonstatic member functions, non-const and const versions. +// `ptmf` stands for "pointer to member function". +template +struct ThrowIfErrorWrapper { + explicit ThrowIfErrorWrapper(absl::Status (C::*ptmf)(Args...)) : ptmf(ptmf) {} + void operator()(C& instance, Args... args) { + ThrowIfError((instance.*ptmf)(std::forward(args)...)); + } + absl::Status (C::*ptmf)(Args...); +}; +template +struct ThrowIfErrorWrapper { + explicit ThrowIfErrorWrapper(absl::Status (C::*ptmf)(Args...) const) + : ptmf(ptmf) {} + void operator()(const C& instance, Args... args) const { + ThrowIfError((instance.*ptmf)(std::forward(args)...)); + } + absl::Status (C::*ptmf)(Args...) const; +}; + +// Utilities for `StatusOr`. +template +T ValueOrThrow(absl::StatusOr v) { + if (!v.ok()) { + throw std::runtime_error(v.status().ToString()); + } + return std::move(v).value(); +} + +template +struct ValueOrThrowWrapper; + +template +ValueOrThrowWrapper(F) -> ValueOrThrowWrapper; + +template +ValueOrThrowWrapper(absl::StatusOr (&)(Args...)) + -> ValueOrThrowWrapper(Args...), + absl::StatusOr (&)(Args...)>; + +template +ValueOrThrowWrapper(absl::StatusOr (C::*)(Args...)) + -> ValueOrThrowWrapper(Args...), C>; + +// Deduction guide for const methods. +template +ValueOrThrowWrapper(absl::StatusOr (C::*)(Args...) const) + -> ValueOrThrowWrapper(Args...) const, C>; + +template +struct ValueOrThrowWrapper(Args...), + absl::StatusOr (&)(Args...)> { + explicit ValueOrThrowWrapper(absl::StatusOr (&f)(Args...)) : func(f) {} + R operator()(Args... args) const { + return ValueOrThrow(func(std::forward(args)...)); + } + absl::StatusOr (&func)(Args...); +}; +template +struct ValueOrThrowWrapper (C::*)(Args...), F> { + explicit ValueOrThrowWrapper(F&& f) : func(std::move(f)) {} + R operator()(Args... args) const { + return ValueOrThrow(func(std::forward(args)...)); + } + F func; +}; +template +struct ValueOrThrowWrapper (C::*)(Args...) const, F> { + explicit ValueOrThrowWrapper(F&& f) : func(std::move(f)) {} + R operator()(Args... args) const { + return ValueOrThrow(func(std::forward(args)...)); + } + F func; +}; + +// For unbound nonstatic member functions, non-const and const versions. +// `ptmf` stands for "pointer to member function". +template +struct ValueOrThrowWrapper(Args...), C> { + explicit ValueOrThrowWrapper(absl::StatusOr (C::*ptmf)(Args...)) + : ptmf(ptmf) {} + R operator()(C& instance, Args... args) { + return ValueOrThrow((instance.*ptmf)(std::forward(args)...)); + } + absl::StatusOr (C::*ptmf)(Args...); +}; +template +struct ValueOrThrowWrapper(Args...) const, C> { + explicit ValueOrThrowWrapper(absl::StatusOr (C::*ptmf)(Args...) const) + : ptmf(ptmf) {} + R operator()(const C& instance, Args... args) const { + return ValueOrThrow((instance.*ptmf)(std::forward(args)...)); + } + absl::StatusOr (C::*ptmf)(Args...) const; +}; + +} // namespace jax + +#endif // JAXLIB_ABSL_STATUS_CASTERS_H_ diff --git a/jaxlib/cpu/BUILD b/jaxlib/cpu/BUILD index bfdd06333428..dfef69418d60 100644 --- a/jaxlib/cpu/BUILD +++ b/jaxlib/cpu/BUILD @@ -57,8 +57,8 @@ pybind_extension( module_name = "_lapack", deps = [ ":lapack_kernels", - "//jaxlib:kernel_pybind11_helpers", - "@pybind11", + "//jaxlib:kernel_nanobind_helpers", + "@nanobind", ], ) @@ -95,9 +95,9 @@ pybind_extension( deps = [ ":ducc_fft_flatbuffers_cc", ":ducc_fft_kernels", - "//jaxlib:kernel_pybind11_helpers", + "//jaxlib:kernel_nanobind_helpers", "@flatbuffers//:runtime_cc", - "@pybind11", + "@nanobind", ], ) diff --git a/jaxlib/cpu/ducc_fft.cc b/jaxlib/cpu/ducc_fft.cc index 674bc5f9e1a2..33e73c5f4214 100644 --- a/jaxlib/cpu/ducc_fft.cc +++ b/jaxlib/cpu/ducc_fft.cc @@ -16,19 +16,19 @@ limitations under the License. #include #include -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/vector.h" #include "jaxlib/cpu/ducc_fft_generated.h" #include "jaxlib/cpu/ducc_fft_kernels.h" -#include "jaxlib/kernel_pybind11_helpers.h" +#include "jaxlib/kernel_nanobind_helpers.h" -namespace py = pybind11; +namespace nb = nanobind; namespace jax { namespace { -py::bytes BuildDynamicDuccFftDescriptor( +nb::bytes BuildDynamicDuccFftDescriptor( const uint32_t ndims, bool is_double, int fft_type, const std::vector &axes, @@ -42,12 +42,12 @@ py::bytes BuildDynamicDuccFftDescriptor( descriptor.forward = forward; flatbuffers::FlatBufferBuilder fbb; fbb.Finish(DynamicDuccFftDescriptor::Pack(fbb, &descriptor)); - return py::bytes(reinterpret_cast(fbb.GetBufferPointer()), + return nb::bytes(reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); } -py::dict Registrations() { - pybind11::dict dict; +nb::dict Registrations() { + nb::dict dict; // TODO(b/287702203): this must be kept until EOY 2023 for backwards // of serialized functions using fft. dict["ducc_fft"] = EncapsulateFunction(DuccFft); @@ -55,11 +55,11 @@ py::dict Registrations() { return dict; } -PYBIND11_MODULE(_ducc_fft, m) { +NB_MODULE(_ducc_fft, m) { m.def("registrations", &Registrations); m.def("dynamic_ducc_fft_descriptor", &BuildDynamicDuccFftDescriptor, - py::arg("ndims"), py::arg("is_double"), py::arg("fft_type"), - py::arg("axes"), py::arg("forward")); + nb::arg("ndims"), nb::arg("is_double"), nb::arg("fft_type"), + nb::arg("axes"), nb::arg("forward")); } } // namespace diff --git a/jaxlib/cpu/lapack.cc b/jaxlib/cpu/lapack.cc index 3e3318aec683..c2879783abf8 100644 --- a/jaxlib/cpu/lapack.cc +++ b/jaxlib/cpu/lapack.cc @@ -15,25 +15,25 @@ limitations under the License. #include -#include "pybind11/pybind11.h" +#include "nanobind/nanobind.h" #include "jaxlib/cpu/lapack_kernels.h" -#include "jaxlib/kernel_pybind11_helpers.h" +#include "jaxlib/kernel_nanobind_helpers.h" namespace jax { namespace { -namespace py = pybind11; +namespace nb = nanobind; void GetLapackKernelsFromScipy() { static bool initialized = false; // Protected by GIL if (initialized) return; - py::module cython_blas = py::module::import("scipy.linalg.cython_blas"); + nb::module_ cython_blas = nb::module_::import_("scipy.linalg.cython_blas"); // Technically this is a Cython-internal API. However, it seems highly likely // it will remain stable because Cython itself needs API stability for // cross-package imports to work in the first place. - py::dict blas_capi = cython_blas.attr("__pyx_capi__"); + nb::dict blas_capi = cython_blas.attr("__pyx_capi__"); auto blas_ptr = [&](const char* name) { - return py::capsule(blas_capi[name]).get_pointer(); + return nb::cast(blas_capi[name]).data(); }; Trsm::fn = reinterpret_cast::FnType*>(blas_ptr("strsm")); Trsm::fn = reinterpret_cast::FnType*>(blas_ptr("dtrsm")); @@ -42,10 +42,11 @@ void GetLapackKernelsFromScipy() { Trsm>::fn = reinterpret_cast>::FnType*>(blas_ptr("ztrsm")); - py::module cython_lapack = py::module::import("scipy.linalg.cython_lapack"); - py::dict lapack_capi = cython_lapack.attr("__pyx_capi__"); + nb::module_ cython_lapack = + nb::module_::import_("scipy.linalg.cython_lapack"); + nb::dict lapack_capi = cython_lapack.attr("__pyx_capi__"); auto lapack_ptr = [&](const char* name) { - return py::capsule(lapack_capi[name]).get_pointer(); + return nb::cast(lapack_capi[name]).data(); }; Getrf::fn = reinterpret_cast::FnType*>(lapack_ptr("sgetrf")); @@ -151,8 +152,8 @@ void GetLapackKernelsFromScipy() { initialized = true; } -py::dict Registrations() { - py::dict dict; +nb::dict Registrations() { + nb::dict dict; dict["blas_strsm"] = EncapsulateFunction(Trsm::Kernel); dict["blas_dtrsm"] = EncapsulateFunction(Trsm::Kernel); dict["blas_ctrsm"] = EncapsulateFunction(Trsm>::Kernel); @@ -224,7 +225,7 @@ py::dict Registrations() { return dict; } -PYBIND11_MODULE(_lapack, m) { +NB_MODULE(_lapack, m) { // Populates the LAPACK kernels from scipy on first call. m.def("initialize", GetLapackKernelsFromScipy); diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 9b543013bb56..33f37975583f 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -114,11 +114,12 @@ pybind_extension( deps = [ ":cublas_kernels", ":cuda_vendor", - "//jaxlib:kernel_pybind11_helpers", + "//jaxlib:kernel_nanobind_helpers", "@xla//xla/stream_executor/cuda:cublas_lib", + "@tsl//tsl/python/lib/core:numpy", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings:str_format", - "@pybind11", + "@nanobind", ], ) @@ -153,11 +154,11 @@ pybind_extension( deps = [ ":cuda_vendor", ":cudnn_rnn_kernels", - "//jaxlib:kernel_pybind11_helpers", + "//jaxlib:absl_status_casters", + "//jaxlib:kernel_nanobind_helpers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings:str_format", - "@pybind11", - "@pybind11_abseil//pybind11_abseil:status_casters", + "@nanobind", ], ) @@ -199,13 +200,14 @@ pybind_extension( ":cuda_gpu_kernel_helpers", ":cuda_vendor", ":cusolver_kernels", - "//jaxlib:kernel_pybind11_helpers", + "//jaxlib:kernel_nanobind_helpers", "@xla//xla/stream_executor/cuda:cudart_stub", "@xla//xla/stream_executor/cuda:cusolver_lib", + "@tsl//tsl/python/lib/core:numpy", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings:str_format", "@local_config_cuda//cuda:cuda_headers", - "@pybind11", + "@nanobind", ], ) @@ -248,9 +250,10 @@ pybind_extension( ":cuda_gpu_kernel_helpers", ":cuda_vendor", ":cusparse_kernels", - "//jaxlib:kernel_pybind11_helpers", + "//jaxlib:kernel_nanobind_helpers", "@xla//xla/stream_executor/cuda:cudart_stub", "@xla//xla/stream_executor/cuda:cusparse_lib", + "@tsl//tsl/python/lib/core:numpy", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", @@ -261,7 +264,7 @@ pybind_extension( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", - "@pybind11", + "@nanobind", ], ) @@ -310,10 +313,10 @@ pybind_extension( ":cuda_lu_pivot_kernels", ":cuda_lu_pivot_kernels_impl", ":cuda_vendor", - "//jaxlib:kernel_pybind11_helpers", + "//jaxlib:kernel_nanobind_helpers", "@xla//xla/stream_executor/cuda:cudart_stub", "@local_config_cuda//cuda:cuda_headers", - "@pybind11", + "@nanobind", ], ) @@ -360,10 +363,10 @@ pybind_extension( deps = [ ":cuda_gpu_kernel_helpers", ":cuda_prng_kernels", - "//jaxlib:kernel_pybind11_helpers", + "//jaxlib:kernel_nanobind_helpers", "@xla//xla/stream_executor/cuda:cudart_stub", "@local_config_cuda//cuda:cuda_headers", - "@pybind11", + "@nanobind", ], ) @@ -444,11 +447,11 @@ pybind_extension( ":cuda_vendor", ":triton_kernels", ":triton_utils", - "//jaxlib:kernel_pybind11_helpers", + "//jaxlib:absl_status_casters", + "//jaxlib:kernel_nanobind_helpers", "//jaxlib/gpu:triton_cc_proto", "@com_google_absl//absl/status:statusor", - "@pybind11", - "@pybind11_abseil//pybind11_abseil:status_casters", + "@nanobind", ], ) diff --git a/jaxlib/gpu/blas.cc b/jaxlib/gpu/blas.cc index 4ba02ed64c8c..9f83b86f61d4 100644 --- a/jaxlib/gpu/blas.cc +++ b/jaxlib/gpu/blas.cc @@ -18,23 +18,23 @@ limitations under the License. #include #include -#include "pybind11/numpy.h" -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/pair.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/str_format.h" #include "jaxlib/gpu/blas_kernels.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/kernel_pybind11_helpers.h" +#include "jaxlib/kernel_nanobind_helpers.h" +#include "tsl/python/lib/core/numpy.h" namespace jax { namespace JAX_GPU_NAMESPACE { namespace { -namespace py = pybind11; +namespace nb = nanobind; // Converts a NumPy dtype to a Type. -BlasType DtypeToBlasType(const py::dtype& np_type) { +BlasType DtypeToBlasType(const dtype& np_type) { static auto* types = new absl::flat_hash_map, BlasType>({ {{'f', 4}, BlasType::F32}, {{'f', 8}, BlasType::F64}, @@ -43,14 +43,15 @@ BlasType DtypeToBlasType(const py::dtype& np_type) { }); auto it = types->find({np_type.kind(), np_type.itemsize()}); if (it == types->end()) { + nb::str repr = nb::repr(np_type); throw std::invalid_argument( - absl::StrFormat("Unsupported dtype %s", py::repr(np_type))); + absl::StrFormat("Unsupported dtype %s", repr.c_str())); } return it->second; } // Returns the descriptor for a GetrfBatched operation. -std::pair BuildGetrfBatchedDescriptor(const py::dtype& dtype, +std::pair BuildGetrfBatchedDescriptor(const dtype& dtype, int b, int n) { BlasType type = DtypeToBlasType(dtype); size_t size = b * sizeof(void*); @@ -58,21 +59,23 @@ std::pair BuildGetrfBatchedDescriptor(const py::dtype& dtype, } // Returns the descriptor for a GetrfBatched operation. -std::pair BuildGeqrfBatchedDescriptor(const py::dtype& dtype, +std::pair BuildGeqrfBatchedDescriptor(const dtype& dtype, int b, int m, int n) { BlasType type = DtypeToBlasType(dtype); size_t size = b * sizeof(void*); return {size, PackDescriptor(GeqrfBatchedDescriptor{type, b, m, n})}; } -py::dict Registrations() { - py::dict dict; +nb::dict Registrations() { + nb::dict dict; dict[JAX_GPU_PREFIX "blas_getrf_batched"] = EncapsulateFunction(GetrfBatched); dict[JAX_GPU_PREFIX "blas_geqrf_batched"] = EncapsulateFunction(GeqrfBatched); return dict; } -PYBIND11_MODULE(_blas, m) { +NB_MODULE(_blas, m) { + tsl::ImportNumpy(); + m.def("registrations", &Registrations); m.def("build_getrf_batched_descriptor", &BuildGetrfBatchedDescriptor); m.def("build_geqrf_batched_descriptor", &BuildGeqrfBatchedDescriptor); diff --git a/jaxlib/gpu/blas_kernels.cc b/jaxlib/gpu/blas_kernels.cc index 93f4f3ceb3d8..329051a0aaae 100644 --- a/jaxlib/gpu/blas_kernels.cc +++ b/jaxlib/gpu/blas_kernels.cc @@ -56,8 +56,6 @@ namespace JAX_GPU_NAMESPACE { namespace { -// Converts a NumPy dtype to a BlasType. - int SizeOfBlasType(BlasType type) { switch (type) { case BlasType::F32: diff --git a/jaxlib/gpu/linalg.cc b/jaxlib/gpu/linalg.cc index d297911d6752..6397647105ad 100644 --- a/jaxlib/gpu/linalg.cc +++ b/jaxlib/gpu/linalg.cc @@ -13,15 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "pybind11/pybind11.h" +#include "nanobind/nanobind.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/lu_pivot_kernels.h" -#include "jaxlib/kernel_pybind11_helpers.h" +#include "jaxlib/kernel_nanobind_helpers.h" namespace jax { namespace JAX_GPU_NAMESPACE { namespace { +namespace nb = nanobind; + std::string BuildLuPivotsToPermutationDescriptor( std::int64_t batch_size, std::int32_t pivot_size, std::int32_t permutation_size) { @@ -29,21 +31,21 @@ std::string BuildLuPivotsToPermutationDescriptor( batch_size, pivot_size, permutation_size}); } -pybind11::dict Registrations() { - pybind11::dict dict; +nb::dict Registrations() { + nb::dict dict; dict[JAX_GPU_PREFIX "_lu_pivots_to_permutation"] = EncapsulateFunction(LuPivotsToPermutation); return dict; } -PYBIND11_MODULE(_linalg, m) { +NB_MODULE(_linalg, m) { m.def("registrations", &Registrations); m.def("lu_pivots_to_permutation_descriptor", [](std::int64_t batch_size, std::int32_t pivot_size, std::int32_t permutation_size) { std::string result = BuildLuPivotsToPermutationDescriptor( batch_size, pivot_size, permutation_size); - return pybind11::bytes(result); + return nb::bytes(result.data(), result.size()); }); } diff --git a/jaxlib/gpu/prng.cc b/jaxlib/gpu/prng.cc index d96d7326d3ac..8aec8d81f8a3 100644 --- a/jaxlib/gpu/prng.cc +++ b/jaxlib/gpu/prng.cc @@ -13,29 +13,30 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "pybind11/pybind11.h" -#include "jaxlib/gpu/gpu_kernel_helpers.h" +#include "nanobind/nanobind.h" #include "jaxlib/gpu/prng_kernels.h" -#include "jaxlib/kernel_pybind11_helpers.h" +#include "jaxlib/kernel_nanobind_helpers.h" namespace jax { namespace JAX_GPU_NAMESPACE { namespace { +namespace nb = nanobind; + std::string BuildThreeFry2x32Descriptor(std::int64_t n) { return PackDescriptorAsString(ThreeFry2x32Descriptor{n}); } -pybind11::dict Registrations() { - pybind11::dict dict; +nb::dict Registrations() { + nb::dict dict; dict[JAX_GPU_PREFIX "_threefry2x32"] = EncapsulateFunction(ThreeFry2x32); return dict; } -PYBIND11_MODULE(_prng, m) { +NB_MODULE(_prng, m) { m.def("registrations", &Registrations); m.def("threefry2x32_descriptor", [](std::int64_t n) { std::string result = BuildThreeFry2x32Descriptor(n); - return pybind11::bytes(result); + return nb::bytes(result.data(), result.size()); }); } diff --git a/jaxlib/gpu/rnn.cc b/jaxlib/gpu/rnn.cc index da1e662b0b65..c35292becc16 100644 --- a/jaxlib/gpu/rnn.cc +++ b/jaxlib/gpu/rnn.cc @@ -13,20 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/pair.h" +#include "jaxlib/absl_status_casters.h" #include "jaxlib/gpu/rnn_kernels.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/kernel_pybind11_helpers.h" -#include "pybind11_abseil/status_casters.h" +#include "jaxlib/kernel_nanobind_helpers.h" namespace jax { namespace JAX_GPU_NAMESPACE { namespace { -namespace py = pybind11; +namespace nb = nanobind; -py::bytes BuildRnnDescriptor(int input_size, int hidden_size, int num_layers, +nb::bytes BuildRnnDescriptor(int input_size, int hidden_size, int num_layers, int batch_size, int max_seq_length, float dropout, bool bidirectional, int workspace_size, int reserve_space_size) { @@ -35,18 +35,18 @@ py::bytes BuildRnnDescriptor(int input_size, int hidden_size, int num_layers, bidirectional, workspace_size, reserve_space_size}); } -py::dict Registrations() { - py::dict dict; +nb::dict Registrations() { + nb::dict dict; dict[JAX_GPU_PREFIX "dnn_rnn"] = EncapsulateFunction(RNNForward); dict[JAX_GPU_PREFIX "dnn_rnn_bwd"] = EncapsulateFunction(RNNBackward); return dict; } -PYBIND11_MODULE(_rnn, m) { +NB_MODULE(_rnn, m) { m.def("registrations", &Registrations); m.def("build_rnn_descriptor", &BuildRnnDescriptor); m.def("compute_rnn_workspace_reserve_space_sizes", - &RnnComputeWorkspaceReserveSpaceSizes); + ValueOrThrowWrapper(RnnComputeWorkspaceReserveSpaceSizes)); } } // namespace diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc index d986ebd519a5..5fa0683a18a7 100644 --- a/jaxlib/gpu/solver.cc +++ b/jaxlib/gpu/solver.cc @@ -19,23 +19,24 @@ limitations under the License. #include #include -#include "pybind11/numpy.h" -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/pair.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/str_format.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/solver_kernels.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/kernel_pybind11_helpers.h" +#include "jaxlib/kernel_nanobind_helpers.h" +#include "tsl/python/lib/core/numpy.h" namespace jax { namespace JAX_GPU_NAMESPACE { namespace { -namespace py = pybind11; + +namespace nb = nanobind; // Converts a NumPy dtype to a Type. -SolverType DtypeToSolverType(const py::dtype& np_type) { +SolverType DtypeToSolverType(const dtype& np_type) { static auto* types = new absl::flat_hash_map, SolverType>({ {{'f', 4}, SolverType::F32}, @@ -45,8 +46,9 @@ SolverType DtypeToSolverType(const py::dtype& np_type) { }); auto it = types->find({np_type.kind(), np_type.itemsize()}); if (it == types->end()) { + nb::str repr = nb::repr(np_type); throw std::invalid_argument( - absl::StrFormat("Unsupported dtype %s", py::repr(np_type))); + absl::StrFormat("Unsupported dtype %s", repr.c_str())); } return it->second; } @@ -54,8 +56,8 @@ SolverType DtypeToSolverType(const py::dtype& np_type) { // getrf: LU decomposition // Returns the workspace size and a descriptor for a getrf operation. -std::pair BuildGetrfDescriptor(const py::dtype& dtype, int b, - int m, int n) { +std::pair BuildGetrfDescriptor(const dtype& dtype, int b, int m, + int n) { SolverType type = DtypeToSolverType(dtype); auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); @@ -93,8 +95,8 @@ std::pair BuildGetrfDescriptor(const py::dtype& dtype, int b, // geqrf: QR decomposition // Returns the workspace size and a descriptor for a geqrf operation. -std::pair BuildGeqrfDescriptor(const py::dtype& dtype, int b, - int m, int n) { +std::pair BuildGeqrfDescriptor(const dtype& dtype, int b, int m, + int n) { SolverType type = DtypeToSolverType(dtype); auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); @@ -134,7 +136,7 @@ std::pair BuildGeqrfDescriptor(const py::dtype& dtype, int b, // csrlsvqr: Linear system solve via Sparse QR // Returns a descriptor for a csrlsvqr operation. -py::bytes BuildCsrlsvqrDescriptor(const py::dtype& dtype, int n, int nnzA, +nb::bytes BuildCsrlsvqrDescriptor(const dtype& dtype, int n, int nnzA, int reorder, double tol) { SolverType type = DtypeToSolverType(dtype); return PackDescriptor(CsrlsvqrDescriptor{type, n, nnzA, reorder, tol}); @@ -145,8 +147,8 @@ py::bytes BuildCsrlsvqrDescriptor(const py::dtype& dtype, int n, int nnzA, // orgqr/ungqr: apply elementary Householder transformations // Returns the workspace size and a descriptor for a geqrf operation. -std::pair BuildOrgqrDescriptor(const py::dtype& dtype, int b, - int m, int n, int k) { +std::pair BuildOrgqrDescriptor(const dtype& dtype, int b, int m, + int n, int k) { SolverType type = DtypeToSolverType(dtype); auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); @@ -188,8 +190,8 @@ std::pair BuildOrgqrDescriptor(const py::dtype& dtype, int b, // Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd // Returns the workspace size and a descriptor for a syevd operation. -std::pair BuildSyevdDescriptor(const py::dtype& dtype, - bool lower, int b, int n) { +std::pair BuildSyevdDescriptor(const dtype& dtype, bool lower, + int b, int n) { SolverType type = DtypeToSolverType(dtype); auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); @@ -227,8 +229,8 @@ std::pair BuildSyevdDescriptor(const py::dtype& dtype, // Supports batches of matrices up to size 32. // Returns the workspace size and a descriptor for a syevj_batched operation. -std::pair BuildSyevjDescriptor(const py::dtype& dtype, - bool lower, int batch, int n) { +std::pair BuildSyevjDescriptor(const dtype& dtype, bool lower, + int batch, int n) { SolverType type = DtypeToSolverType(dtype); auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); @@ -294,8 +296,8 @@ std::pair BuildSyevjDescriptor(const py::dtype& dtype, // Singular value decomposition using QR algorithm: gesvd // Returns the workspace size and a descriptor for a gesvd operation. -std::pair BuildGesvdDescriptor(const py::dtype& dtype, int b, - int m, int n, bool compute_uv, +std::pair BuildGesvdDescriptor(const dtype& dtype, int b, int m, + int n, bool compute_uv, bool full_matrices) { SolverType type = DtypeToSolverType(dtype); auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); @@ -339,9 +341,9 @@ std::pair BuildGesvdDescriptor(const py::dtype& dtype, int b, // Singular value decomposition using Jacobi algorithm: gesvdj // Returns the workspace size and a descriptor for a gesvdj operation. -std::pair BuildGesvdjDescriptor(const py::dtype& dtype, - int batch, int m, int n, - bool compute_uv, int econ) { +std::pair BuildGesvdjDescriptor(const dtype& dtype, int batch, + int m, int n, bool compute_uv, + int econ) { SolverType type = DtypeToSolverType(dtype); auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); @@ -423,8 +425,8 @@ std::pair BuildGesvdjDescriptor(const py::dtype& dtype, #endif // JAX_GPU_CUDA // Returns the workspace size and a descriptor for a geqrf operation. -std::pair BuildSytrdDescriptor(const py::dtype& dtype, - bool lower, int b, int n) { +std::pair BuildSytrdDescriptor(const dtype& dtype, bool lower, + int b, int n) { SolverType type = DtypeToSolverType(dtype); auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); @@ -457,8 +459,8 @@ std::pair BuildSytrdDescriptor(const py::dtype& dtype, return {lwork, PackDescriptor(SytrdDescriptor{type, uplo, b, n, n, lwork})}; } -py::dict Registrations() { - py::dict dict; +nb::dict Registrations() { + nb::dict dict; dict[JAX_GPU_PREFIX "solver_getrf"] = EncapsulateFunction(Getrf); dict[JAX_GPU_PREFIX "solver_geqrf"] = EncapsulateFunction(Geqrf); dict[JAX_GPU_PREFIX "solver_orgqr"] = EncapsulateFunction(Orgqr); @@ -474,7 +476,8 @@ py::dict Registrations() { return dict; } -PYBIND11_MODULE(_solver, m) { +NB_MODULE(_solver, m) { + tsl::ImportNumpy(); m.def("registrations", &Registrations); m.def("build_getrf_descriptor", &BuildGetrfDescriptor); m.def("build_geqrf_descriptor", &BuildGeqrfDescriptor); diff --git a/jaxlib/gpu/sparse.cc b/jaxlib/gpu/sparse.cc index 7bd1d463d7b5..aa90440ca671 100644 --- a/jaxlib/gpu/sparse.cc +++ b/jaxlib/gpu/sparse.cc @@ -19,9 +19,8 @@ limitations under the License. #include #include -#include "pybind11/numpy.h" -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/pair.h" #include "absl/base/casts.h" #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" @@ -30,15 +29,16 @@ limitations under the License. #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/sparse_kernels.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/kernel_pybind11_helpers.h" +#include "jaxlib/kernel_nanobind_helpers.h" +#include "tsl/python/lib/core/numpy.h" -namespace py = pybind11; +namespace nb = nanobind; namespace jax { namespace JAX_GPU_NAMESPACE { namespace { -gpusparseIndexType_t DtypeToCuSparseIndexType(const py::dtype& np_type) { +gpusparseIndexType_t DtypeToCuSparseIndexType(const dtype& np_type) { static auto* types = new absl::flat_hash_map, gpusparseIndexType_t>({ {{'u', 2}, GPUSPARSE_INDEX_16U}, @@ -47,13 +47,14 @@ gpusparseIndexType_t DtypeToCuSparseIndexType(const py::dtype& np_type) { }); auto it = types->find({np_type.kind(), np_type.itemsize()}); if (it == types->end()) { + nb::str repr = nb::repr(np_type); throw std::invalid_argument( - absl::StrFormat("Unsupported index dtype: %s", py::repr(np_type))); + absl::StrFormat("Unsupported index dtype: %s", repr.c_str())); } return it->second; } -gpuDataType DtypeToCudaDataType(const py::dtype& np_type) { +gpuDataType DtypeToCudaDataType(const dtype& np_type) { static auto* types = new absl::flat_hash_map, gpuDataType>({ {{'f', 2}, GPU_R_16F}, {{'c', 4}, GPU_C_16F}, {{'f', 4}, GPU_R_32F}, @@ -69,14 +70,15 @@ gpuDataType DtypeToCudaDataType(const py::dtype& np_type) { }); auto it = types->find({np_type.kind(), np_type.itemsize()}); if (it == types->end()) { + nb::str repr = nb::repr(np_type); throw std::invalid_argument( - absl::StrFormat("Unsupported data dtype: %s", py::repr(np_type))); + absl::StrFormat("Unsupported data dtype: %s", repr.c_str())); } return it->second; } // Returns the descriptor for a Sparse matrix. -SparseMatDescriptor BuildSparseMatDescriptor(const py::dtype& data_dtype, - const py::dtype& index_dtype, +SparseMatDescriptor BuildSparseMatDescriptor(const dtype& data_dtype, + const dtype& index_dtype, int rows, int cols, int nnz, int batch_count, int batch_stride) { @@ -87,7 +89,7 @@ SparseMatDescriptor BuildSparseMatDescriptor(const py::dtype& data_dtype, } // Returns the descriptor for a Dense matrix. -DenseMatDescriptor BuildDenseMatDescriptor(const py::dtype& data_dtype, +DenseMatDescriptor BuildDenseMatDescriptor(const dtype& data_dtype, int rows, int cols, int batch_count, int batch_stride) { gpuDataType value_type = DtypeToCudaDataType(data_dtype); @@ -95,7 +97,7 @@ DenseMatDescriptor BuildDenseMatDescriptor(const py::dtype& data_dtype, } // Returns the descriptor for a Dense vector. -DenseVecDescriptor BuildDenseVecDescriptor(const py::dtype& data_dtype, +DenseVecDescriptor BuildDenseVecDescriptor(const dtype& data_dtype, int size) { gpuDataType value_type = DtypeToCudaDataType(data_dtype); return DenseVecDescriptor{value_type, size}; @@ -105,8 +107,8 @@ DenseVecDescriptor BuildDenseVecDescriptor(const py::dtype& data_dtype, // CsrToDense: Convert CSR matrix to dense matrix // Returns the descriptor for a Sparse matrix. -std::pair BuildCsrToDenseDescriptor( - const py::dtype& data_dtype, const py::dtype& index_dtype, int rows, +std::pair BuildCsrToDenseDescriptor( + const dtype& data_dtype, const dtype& index_dtype, int rows, int cols, int nnz) { auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); @@ -182,8 +184,8 @@ void CsrToDense(gpuStream_t stream, void** buffers, const char* opaque, // CsrFromDense: Convert dense matrix to CSR matrix // Returns the descriptor for a CsrFromDense operation. -std::pair BuildCsrFromDenseDescriptor( - const py::dtype& data_dtype, const py::dtype& index_dtype, int rows, +std::pair BuildCsrFromDenseDescriptor( + const dtype& data_dtype, const dtype& index_dtype, int rows, int cols, int nnz) { auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); @@ -258,9 +260,9 @@ void CsrFromDense(gpuStream_t stream, void** buffers, const char* opaque, // CsrMatvec: Product of CSR matrix and dense vector. // Returns the descriptor for a CsrMatvec operation. -std::pair BuildCsrMatvecDescriptor( - const py::dtype& data_dtype, const py::dtype& x_dtype, - const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows, +std::pair BuildCsrMatvecDescriptor( + const dtype& data_dtype, const dtype& x_dtype, + const dtype& compute_dtype, const dtype& index_dtype, int rows, int cols, int nnz, bool transpose) { auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); @@ -306,9 +308,9 @@ std::pair BuildCsrMatvecDescriptor( // CsrMatmat: Product of CSR matrix and dense matrix. // Returns the descriptor for a CsrMatmat operation. -std::pair BuildCsrMatmatDescriptor( - const py::dtype& data_dtype, const py::dtype& b_dtype, - const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows, +std::pair BuildCsrMatmatDescriptor( + const dtype& data_dtype, const dtype& b_dtype, + const dtype& compute_dtype, const dtype& index_dtype, int rows, int cols, int BCcols, int nnz, bool transpose) { auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); @@ -358,8 +360,8 @@ std::pair BuildCsrMatmatDescriptor( // CooToDense: Convert COO matrix to dense matrix // Returns the descriptor for a CooToDense operation. -std::pair BuildCooToDenseDescriptor( - const py::dtype& data_dtype, const py::dtype& index_dtype, int rows, +std::pair BuildCooToDenseDescriptor( + const dtype& data_dtype, const dtype& index_dtype, int rows, int cols, int nnz) { auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); @@ -395,8 +397,8 @@ std::pair BuildCooToDenseDescriptor( // CooFromDense: Convert dense matrix to COO matrix // Returns the descriptor for a CooFromDense operation. -std::pair BuildCooFromDenseDescriptor( - const py::dtype& data_dtype, const py::dtype& index_dtype, int rows, +std::pair BuildCooFromDenseDescriptor( + const dtype& data_dtype, const dtype& index_dtype, int rows, int cols, int nnz) { auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); @@ -431,9 +433,9 @@ std::pair BuildCooFromDenseDescriptor( // CooMatvec: Product of COO matrix and dense vector. // Returns the descriptor for a CooMatvec operation. -std::pair BuildCooMatvecDescriptor( - const py::dtype& data_dtype, const py::dtype& x_dtype, - const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows, +std::pair BuildCooMatvecDescriptor( + const dtype& data_dtype, const dtype& x_dtype, + const dtype& compute_dtype, const dtype& index_dtype, int rows, int cols, int nnz, bool transpose) { auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); @@ -479,9 +481,9 @@ std::pair BuildCooMatvecDescriptor( // CooMatmat: Product of COO matrix and dense matrix. // Returns the descriptor for a CooMatmat operation. -std::pair BuildCooMatmatDescriptor( - const py::dtype& data_dtype, const py::dtype& b_dtype, - const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows, +std::pair BuildCooMatmatDescriptor( + const dtype& data_dtype, const dtype& b_dtype, + const dtype& compute_dtype, const dtype& index_dtype, int rows, int cols, int BCcols, int nnz, bool transpose, int batch_count, int lhs_batch_stride, int rhs_batch_stride) { // Three batch modes are supported, C_i = A_i B, C_i = A B_i, and @@ -548,7 +550,7 @@ std::pair BuildCooMatmatDescriptor( #endif // if JAX_GPU_HAVE_SPARSE -py::bytes BuildGtsv2Descriptor(int b, int m, int n, int ldb) { +nb::bytes BuildGtsv2Descriptor(int b, int m, int n, int ldb) { return PackDescriptor(Gtsv2Descriptor{b, m, n, ldb}); } @@ -572,8 +574,8 @@ size_t Gtsv2BufferSizeF64(int m, int n, int ldb) { return Gtsv2BufferSize(gpusparseDgtsv2_bufferSizeExt, m, n, ldb); } -py::dict Registrations() { - py::dict dict; +nb::dict Registrations() { + nb::dict dict; #if JAX_GPU_HAVE_SPARSE dict[JAX_GPU_PREFIX "sparse_csr_todense"] = EncapsulateFunction(CsrToDense); dict[JAX_GPU_PREFIX "sparse_csr_fromdense"] = @@ -592,8 +594,9 @@ py::dict Registrations() { return dict; } -PYBIND11_MODULE(_sparse, m) { - m.attr("sparse_supported") = py::bool_(JAX_GPU_HAVE_SPARSE); +NB_MODULE(_sparse, m) { + tsl::ImportNumpy(); + m.attr("sparse_supported") = nb::cast(JAX_GPU_HAVE_SPARSE); m.def("registrations", &Registrations); #if JAX_GPU_HAVE_SPARSE m.def("build_csr_todense_descriptor", &BuildCsrToDenseDescriptor); diff --git a/jaxlib/gpu/triton.cc b/jaxlib/gpu/triton.cc index 253a4dfa9ffa..2238c11f8d33 100644 --- a/jaxlib/gpu/triton.cc +++ b/jaxlib/gpu/triton.cc @@ -5,31 +5,33 @@ #include #include -#include "pybind11/pybind11.h" -#include "pybind11/pytypes.h" -#include "pybind11/stl.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/pair.h" +#include "nanobind/stl/string.h" +#include "nanobind/stl/string_view.h" +#include "nanobind/stl/tuple.h" +#include "nanobind/stl/vector.h" #include "absl/status/statusor.h" +#include "jaxlib/absl_status_casters.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/triton.pb.h" #include "jaxlib/gpu/triton_kernels.h" #include "jaxlib/gpu/triton_utils.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/kernel_pybind11_helpers.h" -#include "pybind11_abseil/status_casters.h" // IWYU pragma: keep +#include "jaxlib/kernel_nanobind_helpers.h" #define CUDA_RETURN_IF_ERROR(expr) JAX_RETURN_IF_ERROR(JAX_AS_STATUS(expr)) - -namespace py = pybind11; +namespace nb = nanobind; namespace jax::JAX_GPU_NAMESPACE { -PYBIND11_MODULE(_triton, m) { - py::class_(m, "TritonKernel") - .def(py::init(m, "TritonKernel") + .def(nb::init()); - py::class_(m, "TritonParameter"); + nb::class_(m, "TritonParameter"); m.def("create_array_parameter", [](size_t bytes_to_zero, size_t ptr_divisibility) { @@ -38,19 +40,19 @@ PYBIND11_MODULE(_triton, m) { }); m.def("create_scalar_parameter", - [](py::bool_ value, - std::string_view dtype) -> absl::StatusOr { + ValueOrThrowWrapper([](bool value, std::string_view dtype) + -> absl::StatusOr { if ((dtype == "i1") || (dtype == "B")) { - return KernelCall::Parameter{static_cast(value)}; + return KernelCall::Parameter{value}; } else { return absl::InvalidArgumentError(std::string("unknown dtype: ") + dtype.data()); } - }); + })); m.def("create_scalar_parameter", - [](py::int_ value, - std::string_view dtype) -> absl::StatusOr { + ValueOrThrowWrapper([](nb::int_ value, std::string_view dtype) + -> absl::StatusOr { if (dtype == "i32") { return KernelCall::Parameter{static_cast(value)}; } else if (dtype == "u32") { @@ -63,11 +65,11 @@ PYBIND11_MODULE(_triton, m) { return absl::InvalidArgumentError(std::string("unknown dtype: ") + dtype.data()); } - }); + })); m.def("create_scalar_parameter", - [](py::float_ value, - std::string_view dtype) -> absl::StatusOr { + ValueOrThrowWrapper([](double value, std::string_view dtype) + -> absl::StatusOr { if (dtype == "fp32") { return KernelCall::Parameter{static_cast(value)}; } else if (dtype == "fp64") { @@ -76,63 +78,68 @@ PYBIND11_MODULE(_triton, m) { return absl::InvalidArgumentError(std::string("unknown dtype: ") + dtype.data()); } - }); + })); - py::class_(m, "TritonKernelCall") - .def(py::init(m, "TritonKernelCall") + .def(nb::init>()) .def("to_proto", [](const KernelCall& kernel_call, std::string name, - std::string metadata) { + nb::bytes metadata) { jax_triton::TritonAnyKernelCall proto; *proto.mutable_kernel_call() = kernel_call.ToProto(); proto.set_name(std::move(name)); - proto.set_metadata(std::move(metadata)); - return py::bytes(proto.SerializeAsString()); + proto.set_metadata(metadata.c_str(), metadata.size()); + std::string s = proto.SerializeAsString(); + return nb::bytes(s.c_str(), s.size()); }); - py::class_(m, "TritonAutotunedKernelCall") - .def(py::init<>([](std::string name, - std::vector> - calls_and_descriptions, - std::vector> - input_output_aliases) { - std::vector configs; - configs.reserve(calls_and_descriptions.size()); - for (auto& [kernel_call, desc] : calls_and_descriptions) { - configs.push_back({std::move(kernel_call), std::move(desc)}); - } - return std::make_unique( - std::move(name), std::move(configs), - std::move(input_output_aliases)); - })) + nb::class_(m, "TritonAutotunedKernelCall") + .def("__init__", + [](AutotunedKernelCall* call, std::string name, + std::vector> + calls_and_descriptions, + std::vector> + input_output_aliases) { + std::vector configs; + configs.reserve(calls_and_descriptions.size()); + for (auto& [kernel_call, desc] : calls_and_descriptions) { + configs.push_back({std::move(kernel_call), std::move(desc)}); + } + new (call) AutotunedKernelCall(std::move(name), std::move(configs), + std::move(input_output_aliases)); + }) .def("to_proto", [](const AutotunedKernelCall& kernel_call, - std::string name, std::string metadata) { + std::string name, nb::bytes metadata) { jax_triton::TritonAnyKernelCall proto; *proto.mutable_autotuned_kernel_call() = kernel_call.ToProto(); proto.set_name(std::move(name)); - proto.set_metadata(std::move(metadata)); - return py::bytes(proto.SerializeAsString()); + proto.set_metadata(metadata.c_str(), metadata.size()); + std::string s = proto.SerializeAsString(); + return nb::bytes(s.c_str(), s.size()); }); m.def("get_custom_call", [] { return EncapsulateFunction(&TritonKernelCall); }); - m.def("get_compute_capability", [](int device) -> absl::StatusOr { - int major, minor; - CUDA_RETURN_IF_ERROR(cuInit(device)); - CUDA_RETURN_IF_ERROR(cuDeviceGetAttribute( - &major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device)); - CUDA_RETURN_IF_ERROR(cuDeviceGetAttribute( - &minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device)); - return major * 10 + minor; - }); + m.def("get_compute_capability", + ValueOrThrowWrapper([](int device) -> absl::StatusOr { + int major, minor; + CUDA_RETURN_IF_ERROR(cuInit(device)); + CUDA_RETURN_IF_ERROR(cuDeviceGetAttribute( + &major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device)); + CUDA_RETURN_IF_ERROR(cuDeviceGetAttribute( + &minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device)); + return major * 10 + minor; + })); m.def("get_serialized_metadata", - [](absl::string_view opaque) -> absl::StatusOr { - JAX_ASSIGN_OR_RETURN(std::string metadata, - GetTritonKernelCallSerializedMetadata(opaque)); - return py::bytes(metadata); - }); + ValueOrThrowWrapper( + [](std::string_view opaque) -> absl::StatusOr { + JAX_ASSIGN_OR_RETURN( + std::string metadata, + GetTritonKernelCallSerializedMetadata(opaque)); + return nb::bytes(metadata.c_str(), metadata.size()); + })); } } // namespace jax::JAX_GPU_NAMESPACE diff --git a/jaxlib/kernel_pybind11_helpers.h b/jaxlib/kernel_nanobind_helpers.h similarity index 53% rename from jaxlib/kernel_pybind11_helpers.h rename to jaxlib/kernel_nanobind_helpers.h index eedb21b8ab08..ef44a92f043f 100644 --- a/jaxlib/kernel_pybind11_helpers.h +++ b/jaxlib/kernel_nanobind_helpers.h @@ -13,34 +13,51 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef JAXLIB_KERNEL_PYBIND11_HELPERS_H_ -#define JAXLIB_KERNEL_PYBIND11_HELPERS_H_ +#ifndef JAXLIB_KERNEL_NANOBIND_HELPERS_H_ +#define JAXLIB_KERNEL_NANOBIND_HELPERS_H_ -#include "pybind11/pybind11.h" +#include + +#include "nanobind/nanobind.h" #include "absl/base/casts.h" #include "jaxlib/kernel_helpers.h" +#include "tsl/python/lib/core/numpy.h" // NOLINT namespace jax { +// Caution: to use this type you must call tsl::ImportNumpy() in your module +// initialization function. Otherwise PyArray_DescrCheck will be nullptr. +class dtype : public nanobind::object { + public: + NB_OBJECT_DEFAULT(dtype, object, "dtype", PyArray_DescrCheck); // NOLINT + + ssize_t itemsize() const { return nanobind::cast(attr("itemsize")); } + + /// Single-character code for dtype's kind. + /// For example, floating point types are 'f' and integral types are 'i'. + char kind() const { return nanobind::cast(attr("kind")); } +}; + // Descriptor objects are opaque host-side objects used to pass data from JAX // to the custom kernel launched by XLA. Currently simply treat host-side // structures as byte-strings; this is not portable across architectures. If // portability is needed, we could switch to using a representation such as // protocol buffers or flatbuffers. -// Packs a descriptor object into a pybind11::bytes structure. +// Packs a descriptor object into a nanobind::bytes structure. // UnpackDescriptor() is available in kernel_helpers.h. template -pybind11::bytes PackDescriptor(const T& descriptor) { - return pybind11::bytes(PackDescriptorAsString(descriptor)); +nanobind::bytes PackDescriptor(const T& descriptor) { + std::string s = PackDescriptorAsString(descriptor); + return nanobind::bytes(s.data(), s.size()); } template -pybind11::capsule EncapsulateFunction(T* fn) { - return pybind11::capsule(absl::bit_cast(fn), +nanobind::capsule EncapsulateFunction(T* fn) { + return nanobind::capsule(absl::bit_cast(fn), "xla._CUSTOM_CALL_TARGET"); } } // namespace jax -#endif // JAXLIB_KERNEL_PYBIND11_HELPERS_H_ +#endif // JAXLIB_KERNEL_NANOBIND_HELPERS_H_ diff --git a/jaxlib/rocm/BUILD.bazel b/jaxlib/rocm/BUILD.bazel index da1dcceaea12..4423c00f3767 100644 --- a/jaxlib/rocm/BUILD.bazel +++ b/jaxlib/rocm/BUILD.bazel @@ -92,12 +92,13 @@ pybind_extension( deps = [ ":hip_vendor", ":hipblas_kernels", - "//jaxlib:kernel_pybind11_helpers", + "//jaxlib:kernel_nanobind_helpers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings:str_format", "@local_config_rocm//rocm:hipblas", "@local_config_rocm//rocm:rocm_headers", - "@pybind11", + "@nanobind", + "@tsl//tsl/python/lib/core:numpy", ], ) @@ -132,12 +133,13 @@ pybind_extension( ":hip_gpu_kernel_helpers", ":hip_vendor", ":hipsolver_kernels", - "//jaxlib:kernel_pybind11_helpers", + "//jaxlib:kernel_nanobind_helpers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings:str_format", "@local_config_rocm//rocm:hipsolver", "@local_config_rocm//rocm:rocm_headers", - "@pybind11", + "@nanobind", + "@tsl//tsl/python/lib/core:numpy", ], ) @@ -172,7 +174,7 @@ pybind_extension( ":hip_gpu_kernel_helpers", ":hip_vendor", ":hipsparse_kernels", - "//jaxlib:kernel_pybind11_helpers", + "//jaxlib:kernel_nanobind_helpers", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", @@ -184,7 +186,8 @@ pybind_extension( "@com_google_absl//absl/synchronization", "@local_config_rocm//rocm:hipsparse", "@local_config_rocm//rocm:rocm_headers", - "@pybind11", + "@nanobind", + "@tsl//tsl/python/lib/core:numpy", ], ) @@ -229,9 +232,9 @@ pybind_extension( ":hip_lu_pivot_kernels", ":hip_lu_pivot_kernels_impl", ":hip_vendor", - "//jaxlib:kernel_pybind11_helpers", + "//jaxlib:kernel_nanobind_helpers", "@local_config_rocm//rocm:rocm_headers", - "@pybind11", + "@nanobind", ], ) @@ -275,9 +278,9 @@ pybind_extension( ":hip_gpu_kernel_helpers", ":hip_prng_kernels", ":hip_vendor", - "//jaxlib:kernel_pybind11_helpers", + "//jaxlib:kernel_nanobind_helpers", "@local_config_rocm//rocm:rocm_headers", - "@pybind11", + "@nanobind", ], ) diff --git a/jaxlib/utils.cc b/jaxlib/utils.cc index 6d6f72ccb120..28201233566a 100644 --- a/jaxlib/utils.cc +++ b/jaxlib/utils.cc @@ -15,11 +15,11 @@ limitations under the License. #include -#include "pybind11/pybind11.h" +#include "nanobind/nanobind.h" #include "absl/cleanup/cleanup.h" #include "absl/container/inlined_vector.h" -namespace py = pybind11; +namespace nb = nanobind; namespace { @@ -32,12 +32,12 @@ PyObject* SafeMap(PyObject* self, PyObject* const* args, Py_ssize_t nargs) { return nullptr; } PyObject* fn = args[0]; - absl::InlinedVector iterators; + absl::InlinedVector iterators; iterators.reserve(nargs - 1); for (Py_ssize_t i = 1; i < nargs; ++i) { PyObject* it = PyObject_GetIter(args[i]); if (!it) return nullptr; - iterators.push_back(py::reinterpret_steal(it)); + iterators.push_back(nb::steal(it)); } // Try to use a length hint to estimate how large a list to allocate. @@ -49,7 +49,7 @@ PyObject* SafeMap(PyObject* self, PyObject* const* args, Py_ssize_t nargs) { length_hint = 2; } - py::list list(length_hint); + nb::list list = nb::steal(PyList_New(length_hint)); int n = 0; // Current true size of the list // The arguments we will pass to fn. We allocate space for one more argument @@ -100,7 +100,7 @@ PyObject* SafeMap(PyObject* self, PyObject* const* args, Py_ssize_t nargs) { return list.release().ptr(); } - py::object out = py::reinterpret_steal(PyObject_Vectorcall( + nb::object out = nb::steal(PyObject_Vectorcall( fn, &values[1], (nargs - 1) | PY_VECTORCALL_ARGUMENTS_OFFSET, /*kwnames=*/nullptr)); if (PyErr_Occurred()) { @@ -135,12 +135,12 @@ PyObject* SafeZip(PyObject* self, PyObject* const* args, Py_ssize_t nargs) { PyErr_SetString(PyExc_TypeError, "safe_zip requires at least 1 argument"); return nullptr; } - absl::InlinedVector iterators; + absl::InlinedVector iterators; iterators.reserve(nargs); for (Py_ssize_t i = 0; i < nargs; ++i) { PyObject* it = PyObject_GetIter(args[i]); if (!it) return nullptr; - iterators.push_back(py::reinterpret_steal(it)); + iterators.push_back(nb::steal(it)); } // Try to use a length hint to estimate how large a list to allocate. @@ -152,22 +152,21 @@ PyObject* SafeZip(PyObject* self, PyObject* const* args, Py_ssize_t nargs) { length_hint = 2; } - py::list list(length_hint); + nb::list list = nb::steal(PyList_New(length_hint)); int n = 0; // Current true size of the list while (true) { - py::object tuple; - py::object v = - py::reinterpret_steal(PyIter_Next(iterators[0].ptr())); + nb::object tuple; + nb::object v = nb::steal(PyIter_Next(iterators[0].ptr())); if (PyErr_Occurred()) return nullptr; if (v.ptr()) { - tuple = py::reinterpret_steal(PyTuple_New(nargs)); + tuple = nb::steal(PyTuple_New(nargs)); if (!tuple.ptr()) return nullptr; PyTuple_SET_ITEM(tuple.ptr(), 0, v.release().ptr()); for (size_t i = 1; i < iterators.size(); ++i) { - v = py::reinterpret_steal(PyIter_Next(iterators[i].ptr())); + v = nb::steal(PyIter_Next(iterators[i].ptr())); if (PyErr_Occurred()) return nullptr; if (!v.ptr()) { PyErr_Format(PyExc_ValueError, @@ -181,7 +180,7 @@ PyObject* SafeZip(PyObject* self, PyObject* const* args, Py_ssize_t nargs) { // No more elements should be left. Checks the other iterators are // exhausted. for (size_t i = 1; i < iterators.size(); ++i) { - v = py::reinterpret_steal(PyIter_Next(iterators[i].ptr())); + v = nb::steal(PyIter_Next(iterators[i].ptr())); if (PyErr_Occurred()) return nullptr; if (v.ptr()) { PyErr_Format(PyExc_ValueError, @@ -206,7 +205,7 @@ PyObject* SafeZip(PyObject* self, PyObject* const* args, Py_ssize_t nargs) { if (PyList_Append(list.ptr(), tuple.ptr()) < 0) { return nullptr; } - tuple = py::object(); + tuple = nb::object(); } ++n; } @@ -220,11 +219,10 @@ PyMethodDef safe_zip_def = { } // namespace - -PYBIND11_MODULE(utils, m) { - py::object module_name = m.attr("__name__"); - m.attr("safe_map") = py::reinterpret_steal( +NB_MODULE(utils, m) { + nb::object module_name = m.attr("__name__"); + m.attr("safe_map") = nb::steal( PyCFunction_NewEx(&safe_map_def, /*self=*/nullptr, module_name.ptr())); - m.attr("safe_zip") = py::reinterpret_steal( + m.attr("safe_zip") = nb::steal( PyCFunction_NewEx(&safe_zip_def, /*self=*/nullptr, module_name.ptr())); } \ No newline at end of file diff --git a/third_party/nanobind/BUILD.bazel b/third_party/nanobind/BUILD.bazel new file mode 100644 index 000000000000..fa975bc2d002 --- /dev/null +++ b/third_party/nanobind/BUILD.bazel @@ -0,0 +1,22 @@ +licenses(["notice"]) + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "nanobind", + srcs = glob([ + "src/*.cpp", + ]), + copts = ["-fexceptions"], + includes = ["include"], + textual_hdrs = glob( + [ + "include/**/*.h", + "src/*.h", + ], + ), + deps = [ + "@robin_map", + "@xla//third_party/python_runtime:headers", + ], +) diff --git a/third_party/nanobind/workspace.bzl b/third_party/nanobind/workspace.bzl new file mode 100644 index 000000000000..5f0749e1b00a --- /dev/null +++ b/third_party/nanobind/workspace.bzl @@ -0,0 +1,26 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Loads the nanobind library.""" + +load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") + +def repo(): + tf_http_archive( + name = "nanobind", + strip_prefix = "nanobind-1.5.0", + sha256 = "fe9d0bfe89b6514eed56a3f223ab257edbaf4fcd322c2acd187901cc2d212596", + urls = tf_mirror_urls("https://github.com/wjakob/nanobind/archive/refs/tags/v1.5.0.tar.gz"), + build_file = "//third_party/nanobind:BUILD.bazel", + ) diff --git a/third_party/robin_map/BUILD.bazel b/third_party/robin_map/BUILD.bazel new file mode 100644 index 000000000000..b649dda31766 --- /dev/null +++ b/third_party/robin_map/BUILD.bazel @@ -0,0 +1,17 @@ +licenses(["notice"]) + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "robin_map", + hdrs = [ + "include/tsl/robin_growth_policy.h", + "include/tsl/robin_hash.h", + "include/tsl/robin_map.h", + "include/tsl/robin_set.h", + ], + copts = ["-fexceptions"], + features = ["-use_header_modules"], # Incompatible with -fexceptions. + includes = ["."], + strip_include_prefix = "include", +) diff --git a/third_party/robin_map/workspace.bzl b/third_party/robin_map/workspace.bzl new file mode 100644 index 000000000000..3b16856b0014 --- /dev/null +++ b/third_party/robin_map/workspace.bzl @@ -0,0 +1,26 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Loads the robin_map library.""" + +load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") + +def repo(): + tf_http_archive( + name = "robin_map", + strip_prefix = "robin-map-1.2.1", + sha256 = "2b54d2c1de2f73bea5c51d5dcbd64813a08caf1bfddcfdeee40ab74e9599e8e3", + urls = tf_mirror_urls("https://github.com/Tessil/robin-map/archive/refs/tags/v1.2.1.tar.gz"), + build_file = "//third_party/robin_map:BUILD.bazel", + )