diff --git a/xformers/csrc/attention/matmul.cpp b/xformers/csrc/attention/matmul.cpp
index 2841912639..e5c7deb1d4 100644
--- a/xformers/csrc/attention/matmul.cpp
+++ b/xformers/csrc/attention/matmul.cpp
@@ -35,8 +35,10 @@ at::Tensor matmul_with_mask(
 }
 
 TORCH_LIBRARY_FRAGMENT(xformers, m) {
+#if !defined(USE_ROCM)
   m.def(TORCH_SELECTIVE_SCHEMA(
       "xformers::matmul_with_mask(Tensor a, Tensor b, Tensor mask) -> Tensor"));
+#endif
 }
 
 TORCH_LIBRARY_IMPL(xformers, CPU, m) {
diff --git a/xformers/csrc/attention/sddmm.cpp b/xformers/csrc/attention/sddmm.cpp
index 7b5e7e3307..f4b810b0af 100644
--- a/xformers/csrc/attention/sddmm.cpp
+++ b/xformers/csrc/attention/sddmm.cpp
@@ -9,6 +9,8 @@
 #include
 
 TORCH_LIBRARY_FRAGMENT(xformers, m) {
+#if !defined(USE_ROCM)
   m.def(TORCH_SELECTIVE_SCHEMA(
       "xformers::sddmm_sputnik(Tensor a, Tensor b, Tensor row_indices, Tensor row_offsets, Tensor column_indices) -> Tensor"));
+#endif
 }
diff --git a/xformers/csrc/attention/sparse_softmax.cpp b/xformers/csrc/attention/sparse_softmax.cpp
index 826e3641e8..074e670e3f 100644
--- a/xformers/csrc/attention/sparse_softmax.cpp
+++ b/xformers/csrc/attention/sparse_softmax.cpp
@@ -9,8 +9,10 @@
 #include
 
 TORCH_LIBRARY_FRAGMENT(xformers, m) {
+#if !defined(USE_ROCM)
   m.def(TORCH_SELECTIVE_SCHEMA(
       "xformers::sparse_softmax_sputnik(int m, int n, Tensor row_indices, Tensor values, Tensor row_offsets, Tensor column_indices) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA(
       "xformers::sparse_softmax_backward_sputnik(int m, int n, Tensor row_indices, Tensor values, Tensor gradient, Tensor row_offsets, Tensor column_indices) -> Tensor"));
+#endif
 }
diff --git a/xformers/csrc/attention/spmm.cpp b/xformers/csrc/attention/spmm.cpp
index fbe7e1bf9c..06271e6c09 100644
--- a/xformers/csrc/attention/spmm.cpp
+++ b/xformers/csrc/attention/spmm.cpp
@@ -9,6 +9,8 @@
 #include
 
 TORCH_LIBRARY_FRAGMENT(xformers, m) {
+#if !defined(USE_ROCM)
   m.def(TORCH_SELECTIVE_SCHEMA(
       "xformers::spmm_sputnik(Tensor b, Tensor row_indices, Tensor values, Tensor row_offsets, Tensor column_indices, int m) -> Tensor"));
+#endif
 }
diff --git a/xformers/csrc/swiglu/swiglu_op.cpp b/xformers/csrc/swiglu/swiglu_op.cpp
index a8880acf6a..6f1ef4d7ad 100644
--- a/xformers/csrc/swiglu/swiglu_op.cpp
+++ b/xformers/csrc/swiglu/swiglu_op.cpp
@@ -8,10 +8,12 @@
 #include
 
 TORCH_LIBRARY_FRAGMENT(xformers, m) {
+#if !defined(USE_ROCM)
   m.def(TORCH_SELECTIVE_SCHEMA(
       "xformers::dual_gemm_silu_identity_mul(Tensor x, Tensor w1, Tensor? b1, Tensor w2, Tensor? b2) -> (Tensor, Tensor, Tensor)"));
   m.def(TORCH_SELECTIVE_SCHEMA(
       "xformers::silu_bw_fused(Tensor x1, Tensor x2, Tensor dx4) -> (Tensor, Tensor)"));
   m.def(TORCH_SELECTIVE_SCHEMA(
       "xformers::gemm_fused_operand_sum(Tensor a, Tensor b, Tensor out_mm, Tensor out_sum) -> (Tensor, Tensor)"));
+#endif
 }
diff --git a/xformers/csrc/swiglu/swiglu_packedw.cpp b/xformers/csrc/swiglu/swiglu_packedw.cpp
index 00fbef12a4..65e3e22a82 100644
--- a/xformers/csrc/swiglu/swiglu_packedw.cpp
+++ b/xformers/csrc/swiglu/swiglu_packedw.cpp
@@ -221,8 +221,10 @@ at::Tensor swiglu_packedw_cuda(
 } // namespace
 
 TORCH_LIBRARY(xformers, m) {
+#if !defined(USE_ROCM)
   m.def(
       "swiglu_packedw(Tensor x, Tensor w1w2, Tensor? b1b2, Tensor w3, Tensor? b3) -> Tensor");
+#endif
 }
 
 TORCH_LIBRARY_IMPL(xformers, Autograd, m) {
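
For context, every hunk above applies the same pattern: the operator schema registrations are wrapped in a preprocessor guard so they are skipped when building for ROCm. Below is a minimal standalone sketch of that pattern, not code from this diff; "xformers::my_op" is a hypothetical schema name used only for illustration.

// Sketch: register an op schema with the PyTorch dispatcher only on
// non-ROCm builds. USE_ROCM is defined by the build system when
// compiling against HIP/ROCm.
#include <torch/library.h>

TORCH_LIBRARY_FRAGMENT(xformers, m) {
#if !defined(USE_ROCM)
  // Schema-only definition; kernel implementations are registered
  // separately via TORCH_LIBRARY_IMPL for each backend.
  m.def(TORCH_SELECTIVE_SCHEMA(
      "xformers::my_op(Tensor a, Tensor b) -> Tensor"));
#endif
}

Guarding the m.def() calls (rather than the whole TORCH_LIBRARY_FRAGMENT block) keeps the library fragment itself valid on ROCm builds while leaving the unsupported schemas unregistered.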