Add conditional compilation for CUDA-dependent code in ROCm builds
qianfengz committed Feb 6, 2024
1 parent 28d3672 commit 12fb41c
Showing 6 changed files with 12 additions and 0 deletions.
2 changes: 2 additions & 0 deletions xformers/csrc/attention/matmul.cpp
@@ -35,8 +35,10 @@ at::Tensor matmul_with_mask(
 }
 
 TORCH_LIBRARY_FRAGMENT(xformers, m) {
+#if !defined(USE_ROCM)
   m.def(TORCH_SELECTIVE_SCHEMA(
       "xformers::matmul_with_mask(Tensor a, Tensor b, Tensor mask) -> Tensor"));
+#endif
 }
 
 TORCH_LIBRARY_IMPL(xformers, CPU, m) {
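For context, the guard relies on the USE_ROCM macro, which is typically supplied by PyTorch's ROCm/HIP build flags when compiling extension sources. The sketch below shows the full shape of the pattern with a hypothetical op name (example_cuda_matmul is not part of xformers or of this commit): the schema definition and the CUDA kernel registration are both compiled out for ROCm, so no schema is left registered without a matching implementation.

// Minimal sketch of the guard pattern, assuming USE_ROCM is defined by the
// ROCm/HIP build. The op name is hypothetical and used only for illustration.
#include <ATen/ATen.h>
#include <torch/library.h>

#if !defined(USE_ROCM)
namespace {
at::Tensor example_cuda_matmul(const at::Tensor& a, const at::Tensor& b) {
  // Placeholder body; a real op would launch a CUDA kernel here.
  return at::matmul(a, b);
}
} // namespace
#endif

TORCH_LIBRARY_FRAGMENT(xformers, m) {
#if !defined(USE_ROCM)
  m.def(TORCH_SELECTIVE_SCHEMA(
      "xformers::example_cuda_matmul(Tensor a, Tensor b) -> Tensor"));
#endif
}

TORCH_LIBRARY_IMPL(xformers, CUDA, m) {
#if !defined(USE_ROCM)
  m.impl(
      TORCH_SELECTIVE_NAME("xformers::example_cuda_matmul"),
      TORCH_FN(example_cuda_matmul));
#endif
}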
2 changes: 2 additions & 0 deletions xformers/csrc/attention/sddmm.cpp
@@ -9,6 +9,8 @@
 #include <torch/types.h>
 
 TORCH_LIBRARY_FRAGMENT(xformers, m) {
+#if !defined(USE_ROCM)
   m.def(TORCH_SELECTIVE_SCHEMA(
       "xformers::sddmm_sputnik(Tensor a, Tensor b, Tensor row_indices, Tensor row_offsets, Tensor column_indices) -> Tensor"));
+#endif
 }
2 changes: 2 additions & 0 deletions xformers/csrc/attention/sparse_softmax.cpp
@@ -9,8 +9,10 @@
 #include <torch/types.h>
 
 TORCH_LIBRARY_FRAGMENT(xformers, m) {
+#if !defined(USE_ROCM)
   m.def(TORCH_SELECTIVE_SCHEMA(
       "xformers::sparse_softmax_sputnik(int m, int n, Tensor row_indices, Tensor values, Tensor row_offsets, Tensor column_indices) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA(
       "xformers::sparse_softmax_backward_sputnik(int m, int n, Tensor row_indices, Tensor values, Tensor gradient, Tensor row_offsets, Tensor column_indices) -> Tensor"));
+#endif
 }
2 changes: 2 additions & 0 deletions xformers/csrc/attention/spmm.cpp
@@ -9,6 +9,8 @@
 #include <torch/types.h>
 
 TORCH_LIBRARY_FRAGMENT(xformers, m) {
+#if !defined(USE_ROCM)
   m.def(TORCH_SELECTIVE_SCHEMA(
       "xformers::spmm_sputnik(Tensor b, Tensor row_indices, Tensor values, Tensor row_offsets, Tensor column_indices, int m) -> Tensor"));
+#endif
 }
2 changes: 2 additions & 0 deletions xformers/csrc/swiglu/swiglu_op.cpp
@@ -8,10 +8,12 @@
 #include <torch/types.h>
 
 TORCH_LIBRARY_FRAGMENT(xformers, m) {
+#if !defined(USE_ROCM)
   m.def(TORCH_SELECTIVE_SCHEMA(
       "xformers::dual_gemm_silu_identity_mul(Tensor x, Tensor w1, Tensor? b1, Tensor w2, Tensor? b2) -> (Tensor, Tensor, Tensor)"));
   m.def(TORCH_SELECTIVE_SCHEMA(
       "xformers::silu_bw_fused(Tensor x1, Tensor x2, Tensor dx4) -> (Tensor, Tensor)"));
   m.def(TORCH_SELECTIVE_SCHEMA(
       "xformers::gemm_fused_operand_sum(Tensor a, Tensor b, Tensor out_mm, Tensor out_sum) -> (Tensor, Tensor)"));
+#endif
 }
2 changes: 2 additions & 0 deletions xformers/csrc/swiglu/swiglu_packedw.cpp
@@ -221,8 +221,10 @@ at::Tensor swiglu_packedw_cuda(
 } // namespace
 
 TORCH_LIBRARY(xformers, m) {
+#if !defined(USE_ROCM)
   m.def(
       "swiglu_packedw(Tensor x, Tensor w1w2, Tensor? b1b2, Tensor w3, Tensor? b3) -> Tensor");
+#endif
 }
 
 TORCH_LIBRARY_IMPL(xformers, Autograd, m) {
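Because only the schema registrations are guarded, a quick way to observe the effect is to ask the PyTorch dispatcher whether a given op was registered at all. The standalone sketch below is not part of this commit; it assumes the xformers extension library has already been loaded into the process (so its TORCH_LIBRARY registrations have run). On a CUDA build the lookup should succeed; on a ROCm build with this change it is expected to come back empty.

// Minimal sketch: probe the dispatcher for one of the guarded ops.
#include <ATen/core/dispatch/Dispatcher.h>
#include <iostream>
#include <string>

static bool has_op(const std::string& qualified_name) {
  // findSchema returns an empty optional when no schema with this name
  // (and empty overload name) has been registered.
  auto handle = c10::Dispatcher::singleton().findSchema(
      c10::OperatorName(qualified_name, /*overload_name=*/""));
  return handle.has_value();
}

int main() {
  std::cout << "xformers::matmul_with_mask registered: "
            << (has_op("xformers::matmul_with_mask") ? "yes" : "no") << "\n";
  return 0;
}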
