Skip to content

Commit

Permalink
Rearrange the LAPACK handler definitions in jaxlib to avoid duplicate…
Browse files Browse the repository at this point in the history
… handler errors.

When linking the jaxlib `cpu_kernels` target and importing JAX, we currently silently fail to instantiate the CPU backend. This refactor means that we only ever define one version of the handlers.

PiperOrigin-RevId: 657186057
  • Loading branch information
dfm authored and Rifur13 committed Jul 29, 2024
1 parent 316d0fe commit bdc93d9
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 152 deletions.
2 changes: 0 additions & 2 deletions jaxlib/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ cc_library(
pybind_extension(
name = "_lapack",
srcs = ["lapack.cc"],
hdrs = ["lapack.h"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
Expand All @@ -77,7 +76,6 @@ pybind_extension(
cc_library(
name = "cpu_kernels",
srcs = ["cpu_kernels.cc"],
hdrs = ["lapack.h"],
visibility = ["//visibility:public"],
deps = [
":lapack_kernels",
Expand Down
1 change: 0 additions & 1 deletion jaxlib/cpu/cpu_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ limitations under the License.

#include <complex>

#include "jaxlib/cpu/lapack.h"
#include "jaxlib/cpu/lapack_kernels.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"
Expand Down
2 changes: 0 additions & 2 deletions jaxlib/cpu/lapack.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "jaxlib/cpu/lapack.h"

#include <complex>

#include "nanobind/nanobind.h"
Expand Down
147 changes: 0 additions & 147 deletions jaxlib/cpu/lapack.h

This file was deleted.

122 changes: 122 additions & 0 deletions jaxlib/cpu/lapack_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1387,4 +1387,126 @@ template struct Sytrd<double>;
template struct Sytrd<std::complex<float>>;
template struct Sytrd<std::complex<double>>;

// FFI Definition Macros (by DataType)

// Defines the XLA FFI handler symbol `name`, backed by
// TriMatrixEquationSolver<data_type>::Kernel.
// Binding, in order: input buffers x and y, alpha (a BufferR0), the result
// buffer y_out, then the attributes "side", "uplo", "trans_x" and "diag".
// NOTE(review): the Arg/Ret/Attr order presumably has to mirror the Kernel's
// parameter order -- confirm against the Kernel declaration before reordering.
#define JAX_CPU_DEFINE_TRSM(name, data_type) \
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
name, TriMatrixEquationSolver<data_type>::Kernel, \
::xla::ffi::Ffi::Bind() \
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
.Arg<::xla::ffi::Buffer<data_type>>(/*y*/) \
.Arg<::xla::ffi::BufferR0<data_type>>(/*alpha*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*y_out*/) \
.Attr<MatrixParams::Side>("side") \
.Attr<MatrixParams::UpLo>("uplo") \
.Attr<MatrixParams::Transpose>("trans_x") \
.Attr<MatrixParams::Diag>("diag"))

// Defines the XLA FFI handler symbol `name`, backed by
// LuDecomposition<data_type>::Kernel.
// Binding, in order: input buffer x, then the result buffers x_out
// (data_type), ipiv and info (both LapackIntDtype). No attributes are bound.
#define JAX_CPU_DEFINE_GETRF(name, data_type) \
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
name, LuDecomposition<data_type>::Kernel, \
::xla::ffi::Ffi::Bind() \
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*ipiv*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))

// Defines the XLA FFI handler symbol `name`, backed by
// QrFactorization<data_type>::Kernel.
// Binding, in order: input buffer x, then the result buffers x_out, tau,
// info (LapackIntDtype) and work. No attributes are bound.
// NOTE(review): `work` looks like a scratch buffer exposed as a result --
// confirm against the Kernel implementation.
#define JAX_CPU_DEFINE_GEQRF(name, data_type) \
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
name, QrFactorization<data_type>::Kernel, \
::xla::ffi::Ffi::Bind() \
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*tau*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*work*/))

// Defines the XLA FFI handler symbol `name`, backed by
// OrthogonalQr<data_type>::Kernel.
// Binding, in order: input buffers x and tau, then the result buffers x_out,
// info (LapackIntDtype) and work. No attributes are bound.
// The same macro is used for both the real (orgqr) and complex (ungqr)
// handler names.
#define JAX_CPU_DEFINE_ORGQR(name, data_type) \
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
name, OrthogonalQr<data_type>::Kernel, \
::xla::ffi::Ffi::Bind() \
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
.Arg<::xla::ffi::Buffer<data_type>>(/*tau*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*work*/))

// Defines the XLA FFI handler symbol `name`, backed by
// CholeskyFactorization<data_type>::Kernel.
// Binding, in order: input buffer x, the "uplo" attribute, then the result
// buffers x_out and info (LapackIntDtype).
// Unlike the other definition macros here, the attribute is bound before the
// results rather than last.
// NOTE(review): presumably this matches the Kernel's parameter order --
// confirm against the Kernel declaration before "normalizing" the order.
#define JAX_CPU_DEFINE_POTRF(name, data_type) \
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
name, CholeskyFactorization<data_type>::Kernel, \
::xla::ffi::Ffi::Bind() \
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
.Attr<MatrixParams::UpLo>("uplo") \
.Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))

// Defines the XLA FFI handler symbol `name`, backed by
// SingularValueDecomposition<data_type>::Kernel.
// Binding, in order: input buffer x, then the result buffers x_out, s, u, vt
// (all data_type), info and iwork (both LapackIntDtype) and work (data_type),
// followed by the "mode" attribute (svd::ComputationMode).
// This variant is instantiated for F32 and F64 only; the complex types use
// JAX_CPU_DEFINE_GESDD_COMPLEX.
#define JAX_CPU_DEFINE_GESDD(name, data_type) \
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
name, SingularValueDecomposition<data_type>::Kernel, \
::xla::ffi::Ffi::Bind() \
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*s*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*u*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*vt*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*iwork*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*work*/) \
.Attr<svd::ComputationMode>("mode"))

// Defines the XLA FFI handler symbol `name`, backed by
// SingularValueDecompositionComplex<data_type>::Kernel.
// Binding, in order: input buffer x, then the result buffers x_out, s, u, vt,
// info, rwork, iwork and work, followed by the "mode" attribute
// (svd::ComputationMode).
// Differences from JAX_CPU_DEFINE_GESDD: `s` uses ToReal(data_type) instead of
// data_type, and there is an additional `rwork` result, also of
// ToReal(data_type), between info and iwork.
#define JAX_CPU_DEFINE_GESDD_COMPLEX(name, data_type) \
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
name, SingularValueDecompositionComplex<data_type>::Kernel, \
::xla::ffi::Ffi::Bind() \
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/) \
.Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(/*s*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*u*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*vt*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/) \
.Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(/*rwork*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*iwork*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*work*/) \
.Attr<svd::ComputationMode>("mode"))

// FFI Handlers
//
// One handler symbol is defined per routine and data type. The symbol names
// use the s/d/c/z prefix letter for F32/F64/C64/C128 respectively.

// TRSM handlers (TriMatrixEquationSolver).
JAX_CPU_DEFINE_TRSM(blas_strsm_ffi, ::xla::ffi::DataType::F32);
JAX_CPU_DEFINE_TRSM(blas_dtrsm_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_TRSM(blas_ctrsm_ffi, ::xla::ffi::DataType::C64);
JAX_CPU_DEFINE_TRSM(blas_ztrsm_ffi, ::xla::ffi::DataType::C128);

// GETRF handlers (LuDecomposition).
JAX_CPU_DEFINE_GETRF(lapack_sgetrf_ffi, ::xla::ffi::DataType::F32);
JAX_CPU_DEFINE_GETRF(lapack_dgetrf_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_GETRF(lapack_cgetrf_ffi, ::xla::ffi::DataType::C64);
JAX_CPU_DEFINE_GETRF(lapack_zgetrf_ffi, ::xla::ffi::DataType::C128);

// GEQRF handlers (QrFactorization).
JAX_CPU_DEFINE_GEQRF(lapack_sgeqrf_ffi, ::xla::ffi::DataType::F32);
JAX_CPU_DEFINE_GEQRF(lapack_dgeqrf_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_GEQRF(lapack_cgeqrf_ffi, ::xla::ffi::DataType::C64);
JAX_CPU_DEFINE_GEQRF(lapack_zgeqrf_ffi, ::xla::ffi::DataType::C128);

// ORGQR handlers (OrthogonalQr); the complex symbols are named *ungqr*.
JAX_CPU_DEFINE_ORGQR(lapack_sorgqr_ffi, ::xla::ffi::DataType::F32);
JAX_CPU_DEFINE_ORGQR(lapack_dorgqr_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_ORGQR(lapack_cungqr_ffi, ::xla::ffi::DataType::C64);
JAX_CPU_DEFINE_ORGQR(lapack_zungqr_ffi, ::xla::ffi::DataType::C128);

// POTRF handlers (CholeskyFactorization).
JAX_CPU_DEFINE_POTRF(lapack_spotrf_ffi, ::xla::ffi::DataType::F32);
JAX_CPU_DEFINE_POTRF(lapack_dpotrf_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_POTRF(lapack_cpotrf_ffi, ::xla::ffi::DataType::C64);
JAX_CPU_DEFINE_POTRF(lapack_zpotrf_ffi, ::xla::ffi::DataType::C128);

// GESDD handlers: the real types use SingularValueDecomposition, the complex
// types use SingularValueDecompositionComplex via the _COMPLEX macro.
JAX_CPU_DEFINE_GESDD(lapack_sgesdd_ffi, ::xla::ffi::DataType::F32);
JAX_CPU_DEFINE_GESDD(lapack_dgesdd_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_GESDD_COMPLEX(lapack_cgesdd_ffi, ::xla::ffi::DataType::C64);
JAX_CPU_DEFINE_GESDD_COMPLEX(lapack_zgesdd_ffi, ::xla::ffi::DataType::C128);

// The JAX_CPU_DEFINE_* helpers are only needed for the handler definitions
// above; undefine all seven so they do not remain defined past this point.
#undef JAX_CPU_DEFINE_TRSM
#undef JAX_CPU_DEFINE_GETRF
#undef JAX_CPU_DEFINE_GEQRF
#undef JAX_CPU_DEFINE_ORGQR
#undef JAX_CPU_DEFINE_POTRF
#undef JAX_CPU_DEFINE_GESDD
#undef JAX_CPU_DEFINE_GESDD_COMPLEX

} // namespace jax
26 changes: 26 additions & 0 deletions jaxlib/cpu/lapack_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,32 @@ struct Sytrd {
static int64_t Workspace(lapack_int lda, lapack_int n);
};

// Declare all the handler symbols.
// Each declaration here has a matching XLA_FFI_DEFINE_HANDLER_SYMBOL
// definition in lapack_kernels.cc; keep the two lists in sync when adding or
// removing a handler. Names use the s/d/c/z prefix letter for the
// F32/F64/C64/C128 variants of each routine.
XLA_FFI_DECLARE_HANDLER_SYMBOL(blas_strsm_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(blas_dtrsm_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(blas_ctrsm_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(blas_ztrsm_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgetrf_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgetrf_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgetrf_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zgetrf_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgeqrf_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgeqrf_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgeqrf_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zgeqrf_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sorgqr_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dorgqr_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cungqr_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zungqr_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_spotrf_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dpotrf_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cpotrf_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zpotrf_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgesdd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgesdd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgesdd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zgesdd_ffi);

} // namespace jax

#endif // JAXLIB_CPU_LAPACK_KERNELS_H_

0 comments on commit bdc93d9

Please sign in to comment.