Flash attention v2 forward #10484

Merged: 27 commits, Apr 18, 2024

Commits
2e16534: updata cmake files for flash-attention. (cccddd77, Mar 25, 2024)
f8bcb78: update flash_attention.cmake (cccddd77, Mar 29, 2024)
c74d206: add flash-attention 2 forward kernel. (cccddd77, Mar 29, 2024)
79a4e39: Merge branch 'Oneflow-Inc:master' into flash_attn_v2 (cccddd77, Mar 29, 2024)
60d616b: fix bugs. (cccddd77, Apr 1, 2024)
45f9ee0: Merge branch 'flash_attn_v2' of https://github.com/cccddd77/oneflow i… (cccddd77, Apr 1, 2024)
78aa88b: format check. (cccddd77, Apr 1, 2024)
599f11f: fix bug. (cccddd77, Apr 2, 2024)
0062b64: add output tensor for rng_state. (cccddd77, Apr 9, 2024)
5058ebd: fix typo. (cccddd77, Apr 11, 2024)
7fbaf6a: rm duplicated license. (cccddd77, Apr 11, 2024)
bb68ee1: format check. (cccddd77, Apr 11, 2024)
996c00c: cmake format. (cccddd77, Apr 11, 2024)
6aa9f1b: update flash-attention forward sbp. (cccddd77, Apr 15, 2024)
803ae51: add cuda_version check, flash-attention only support cuda_version >= (cccddd77, Apr 15, 2024)
a55977d: fix cmake bugs. (cccddd77, Apr 15, 2024)
d94c500: auto format by CI (oneflow-ci-bot, Apr 15, 2024)
df1c97a: rm useless op trait of flash-attn. (cccddd77, Apr 16, 2024)
2e641f4: change cuda version requirement to >= 11.7 (cccddd77, Apr 16, 2024)
a6780e2: add test file. (cccddd77, Apr 16, 2024)
7b2e632: typo. (cccddd77, Apr 16, 2024)
359cc12: format check. (cccddd77, Apr 16, 2024)
a83daf8: Merge branch 'master' into flash_attention_v2_forward (cccddd77, Apr 16, 2024)
a99d71e: use git to get flash-attention. (cccddd77, Apr 17, 2024)
ffe3f74: fix test bug on cpu only env. (cccddd77, Apr 17, 2024)
fcc088f: auto format by CI (oneflow-ci-bot, Apr 17, 2024)
4c06137: change test file. (cccddd77, Apr 17, 2024)

5 changes: 5 additions & 0 deletions cmake/oneflow.cmake
@@ -632,6 +632,11 @@ if(BUILD_CPP_API)
checkdirandappendslash(DIR ${TRT_FLASH_ATTENTION_LIBRARY_DIR} OUTPUT
TRT_FLASH_ATTENTION_LIBRARY_DIR_APPENDED)
list(APPEND LIBONEFLOW_THIRD_PARTY_DIRS ${TRT_FLASH_ATTENTION_LIBRARY_DIR_APPENDED})
if(CUDA_VERSION VERSION_GREATER_EQUAL "11.7")
checkdirandappendslash(DIR ${FLASH_ATTENTION_LIBRARY_DIR} OUTPUT
FLASH_ATTENTION_LIBRARY_DIR_APPENDED)
list(APPEND LIBONEFLOW_THIRD_PARTY_DIRS ${FLASH_ATTENTION_LIBRARY_DIR_APPENDED})
endif()
if(WITH_CUTLASS)
checkdirandappendslash(DIR ${CUTLASS_LIBRARY_DIR} OUTPUT CUTLASS_LIBRARY_DIR_APPENDED)
list(APPEND LIBONEFLOW_THIRD_PARTY_DIRS ${CUTLASS_LIBRARY_DIR_APPENDED})
8 changes: 8 additions & 0 deletions cmake/third_party.cmake
@@ -145,6 +145,9 @@ if(BUILD_CUDA)
include(nccl)
include(cutlass)
include(trt_flash_attention)
if(CUDA_VERSION VERSION_GREATER_EQUAL "11.7")
include(flash_attention)
endif()

list(APPEND oneflow_third_party_libs ${NCCL_LIBRARIES})
list(APPEND oneflow_third_party_libs ${CUDNN_LIBRARIES})
@@ -164,6 +167,11 @@ if(BUILD_CUDA)
list(APPEND oneflow_third_party_dependencies trt_flash_attention)
list(APPEND oneflow_third_party_libs ${TRT_FLASH_ATTENTION_LIBRARIES})
list(APPEND ONEFLOW_THIRD_PARTY_INCLUDE_DIRS ${TRT_FLASH_ATTENTION_INCLUDE_DIR})
if(CUDA_VERSION VERSION_GREATER_EQUAL "11.7")
list(APPEND oneflow_third_party_dependencies flash_attention)
list(APPEND oneflow_third_party_libs ${FLASH_ATTENTION_LIBRARIES})
list(APPEND ONEFLOW_THIRD_PARTY_INCLUDE_DIRS ${FLASH_ATTENTION_INCLUDE_DIR})
endif()
endif()

if(BUILD_RDMA)
39 changes: 39 additions & 0 deletions cmake/third_party/flash_attention.cmake
@@ -0,0 +1,39 @@
include(ExternalProject)

[Review comment from a Contributor] This third-party library should be moved into the external directory later on; as a rule, newly added third-party dependencies all go under external.

find_package(Threads)

# NOTE: A git version of 1.6.5 or later is required if this download method is used.
find_package(Git QUIET REQUIRED)

set(FLASH_ATTENTION_PROJECT flash_attention)

set(FLASH_ATTENTION_URL https://github.com/Oneflow-Inc/flash-attention-v2.git)
set(FLASH_ATTENTION_TAG eed2e82b880e06237af3e50ceac4cf6728b15645)

set(FLASH_ATTENTION_INSTALL_DIR ${THIRD_PARTY_DIR}/flash_attention)
set(FLASH_ATTENTION_INCLUDE_DIR ${FLASH_ATTENTION_INSTALL_DIR}/include CACHE PATH "" FORCE)
set(FLASH_ATTENTION_LIBRARY_DIR ${FLASH_ATTENTION_INSTALL_DIR}/lib CACHE PATH "" FORCE)
set(FLASH_ATTENTION_LIBRARIES ${FLASH_ATTENTION_LIBRARY_DIR}/libflash_attention.so)

if(THIRD_PARTY)
ExternalProject_Add(
${FLASH_ATTENTION_PROJECT}
PREFIX flash_attention
GIT_REPOSITORY ${FLASH_ATTENTION_URL}
GIT_TAG ${FLASH_ATTENTION_TAG}
UPDATE_COMMAND ""
BUILD_BYPRODUCTS ${FLASH_ATTENTION_LIBRARIES}
CMAKE_ARGS -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE}
-DCMAKE_CXX_FLAGS:STRING=${CMAKE_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG}
-DCMAKE_CXX_FLAGS_RELEASE:STRING=${CMAKE_CXX_FLAGS_RELEASE}
-DCMAKE_CUDA_ARCHITECTURES:STRING=${CMAKE_CUDA_ARCHITECTURES}
CMAKE_CACHE_ARGS
-DCMAKE_CUDA_COMPILER:STRING=${CUDAToolkit_NVCC_EXECUTABLE}
-DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER}
-DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER}
-DCMAKE_INSTALL_PREFIX:PATH=${FLASH_ATTENTION_INSTALL_DIR}
-DCMAKE_INSTALL_LIBDIR:PATH=${FLASH_ATTENTION_LIBRARY_DIR}
-DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE}
-DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE})
endif(THIRD_PARTY)
4 changes: 4 additions & 0 deletions oneflow/core/functional/functional_api.yaml
@@ -2684,6 +2684,10 @@
signature: "TensorTuple (Tensor x, Tensor bias, Tensor mask, *, Float fill_value=0.0, Float scale=1.0, Float p=0.5, Bool training=True, Generator generator=None) => FusedBiasAddScaleMaskSoftmaxDropout"
bind_python: True

- name: "scaled_dot_product_attention"
signature: "Tensor (Tensor query, Tensor key, Tensor value, Tensor attn_mask=None, Float dropout_p=0.0, Bool is_causal=False, Float scale=None, Int64 seed=0) => ScaledDotProductFlashAttention"
bind_python: True

- name: "fused_multi_head_attention_inference"
signature: "Tensor (Tensor query, Tensor key, Tensor value, Int64 num_heads, Bool causal=False, Int64 query_hidden_slice_start=0, Int64 query_hidden_slice_end=-1, Int64 key_hidden_slice_start=0, Int64 key_hidden_slice_end=-1, Int64 value_hidden_slice_start=0, Int64 value_hidden_slice_end=-1, Tensor attn_bias=None, Int64 causal_diagonal_offset=0) => FusedMultiHeadAttentionInference"
bind_python: True
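
A minimal usage sketch of the new binding follows. It assumes the entry is exposed as oneflow._C.scaled_dot_product_attention (the usual location for bind_python functions generated from functional_api.yaml; the exact public wrapper may differ), and it requires a CUDA >= 11.7 build with half-precision tensors on the GPU, since the flash-attention kernel targets fp16/bf16.

import oneflow as flow

# (batch, num_heads, seq_len, head_dim) layout, matching the functor in nn_functor.cpp below
q = flow.randn(2, 8, 128, 64, device="cuda", dtype=flow.float16)
k = flow.randn(2, 8, 128, 64, device="cuda", dtype=flow.float16)
v = flow.randn(2, 8, 128, 64, device="cuda", dtype=flow.float16)

# hypothetical call path; argument names follow the yaml signature above
out = flow._C.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
print(out.shape)  # (2, 8, 128, 64), same layout as the query
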
119 changes: 119 additions & 0 deletions oneflow/core/functional/impl/nn_functor.cpp
@@ -28,6 +28,7 @@ limitations under the License.
#include "oneflow/user/kernels/dropout_kernel.h"
#include "oneflow/user/kernels/distributions/common.h"
#include "oneflow/user/kernels/random_seed_util.h"
#include "oneflow/user/kernels/scaled_dot_product_attention_kernel.h"

#include "oneflow/core/common/container_util.h"
#include "fmt/core.h"
@@ -5420,6 +5421,123 @@ class NonContiguousBinaryOpGradFunctor {
std::shared_ptr<OpExpr> op_;
};

namespace {
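// The flash-attention kernel requires the head (last) dimension to be a multiple of 8,
// so this helper zero-pads the last dim up to the next multiple of alignment_size.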

template<int alignment_size>
Maybe<one::Tensor> pad_last_dim(const std::shared_ptr<one::Tensor>& input) {
auto num_dims = input->shape()->NumAxes();
auto last_dim_size = input->shape()->At(num_dims - 1);
if (last_dim_size % alignment_size == 0) { return input; }
auto pad_count = alignment_size - (last_dim_size % alignment_size);

return JUST(functional::Pad(input, {0, pad_count}, "constant", Scalar(0)));
}

} // namespace

class ScaledDotProductFlashAttentionFunctor {
public:
ScaledDotProductFlashAttentionFunctor() {
#if CUDA_VERSION >= 11070
op_ = CHECK_JUST(one::OpBuilder("scaled_dot_product_flash_attention")
.Input("query")
.Input("key")
.Input("value")
.Output("out")
.Output("softmax_lse")
.Output("rng_state")
.Build());
#endif // CUDA_VERSION >= 11070
}

Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& query,
const std::shared_ptr<one::Tensor>& key,
const std::shared_ptr<one::Tensor>& value,
const Optional<one::Tensor>& attn_mask, const float& dropout_p,
const bool& is_causal, const Optional<float>& scale,
const int64_t& seed = 0) const {
#if CUDA_VERSION >= 11070
const auto og_size = query->shape()->At(3);
const auto batch_size = query->shape()->At(0);
const auto seqlen_q = query->shape()->At(2);
const auto num_heads = query->shape()->At(1);
const auto num_heads_k = key->shape()->At(1);
const auto max_seqlen_batch_k = key->shape()->At(2);
const auto max_seqlen_batch_v = value->shape()->At(2);

CHECK_EQ_OR_RETURN(batch_size, key->shape()->At(0))
<< " key has different batch size from query.";
CHECK_EQ_OR_RETURN(batch_size, value->shape()->At(0))
<< " value has different batch size from query.";
CHECK_EQ_OR_RETURN(num_heads_k, value->shape()->At(1))
<< " value has different num_heads from key.";
CHECK_EQ_OR_RETURN(max_seqlen_batch_k, max_seqlen_batch_v)
<< "value has different seqlen from key.";
CHECK_EQ_OR_RETURN(og_size, key->shape()->At(3)) << " key has different head dims from query.";
CHECK_EQ_OR_RETURN(og_size, value->shape()->At(3))
<< " value has different head dims from query.";

// Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
// Key (Batch x Num_heads x KV_seq_len x Dim_per_head)
// Value (Batch x Num_heads x KV_seq_len x Dim_per_head)
std::shared_ptr<Tensor> q_padded, k_padded, v_padded;
bool padded = og_size % 8;
if (padded) {
q_padded = JUST(pad_last_dim<8>(query));
k_padded = JUST(pad_last_dim<8>(key));
v_padded = JUST(pad_last_dim<8>(value));
} else {
q_padded = query;
k_padded = key;
v_padded = value;
}

auto q_ = JUST(functional::Transpose(q_padded, {0, 2, 1, 3}));
auto k_ = JUST(functional::Transpose(k_padded, {0, 2, 1, 3}));
auto v_ = JUST(functional::Transpose(v_padded, {0, 2, 1, 3}));
// Query -> Query(Batch x Q_seq_len x Num_heads x Dim_per_head)
// Key -> Key (Batch x KV_seq_len x Num_heads x Dim_per_head)
// Value -> Value(Batch x KV_seq_len x Num_heads x Dim_per_head)

const auto& scale_ =
scale.has_value() ? scale : (1.0f / std::sqrt(static_cast<float>(query->shape()->At(3))));

auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("p_dropout", "softmax_scale", "is_causal",
"window_size_left", "window_size_right", "seed");
attrs.SetAllAttrs(dropout_p, scale_, is_causal, -1, -1, seed);
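// The two -1 values are window_size_left / window_size_right; -1 disables
// sliding-window (local) attention in the flash-attention v2 convention.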

auto gen = JUST(one::DefaultAutoGenerator());
gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), query));
const auto& state = std::make_shared<ScaledDotProductFlashAttentionKernelState>(gen);
OpExprInterpContext ctx(attrs, state);

std::shared_ptr<one::Tensor> output_ =
JUST(OpInterpUtil::Dispatch<one::Tensor>(*op_, {q_, k_, v_}, ctx));

auto output_padded = JUST(functional::Transpose(output_, {0, 2, 1, 3}));

std::shared_ptr<Tensor> output;
if (padded) {
output =
JUST(functional::Slice(output_padded, {0, 0, 0, 0},
{batch_size, num_heads, seqlen_q, og_size}, {1, 1, 1, 1}, false));
} else {
output = output_padded;
}

return output;
#endif // CUDA_VERSION >= 11070

UNIMPLEMENTED_THEN_RETURN() << "only support CUDA_VERSION >= 11070.";
}

private:
#if CUDA_VERSION >= 11070
std::shared_ptr<OpExpr> op_;
#endif // CUDA_VERSION >= 11070
};

} // namespace impl

ONEFLOW_FUNCTION_LIBRARY(m) {
@@ -5557,6 +5675,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::NonContiguousBinaryOpGradFunctor>("NonContiguousBinaryOpGrad");
m.add_functor<impl::MultiTensorYoloV5WeightUpdateFunctor>("MultiTensorYoloV5WeightUpdate");
m.add_functor<impl::FusedClipGradFunctor>("FusedClipGrad");
m.add_functor<impl::ScaledDotProductFlashAttentionFunctor>("ScaledDotProductFlashAttention");
}

} // namespace functional
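
For reference, and setting aside dropout and the pad/slice bookkeeping in the functor above, the kernel computes ordinary scaled dot product attention with the same 1/sqrt(head_dim) default scale. The NumPy sketch below is an illustration for numerical comparison, not the library's code, and it assumes equal query/key lengths for the causal mask.

import numpy as np

def sdpa_reference(q, k, v, is_causal=False, scale=None):
    # q, k, v: (batch, num_heads, seq_len, head_dim)
    if scale is None:
        scale = 1.0 / np.sqrt(q.shape[-1])
    logits = np.einsum("bhqd,bhkd->bhqk", q, k) * scale
    if is_causal:
        seqlen = logits.shape[-1]
        future = np.triu(np.ones((seqlen, seqlen), dtype=bool), k=1)
        logits = np.where(future, -np.inf, logits)
    logits = logits - logits.max(axis=-1, keepdims=True)  # stable softmax
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return np.einsum("bhqk,bhkd->bhqd", weights, v)
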
26 changes: 26 additions & 0 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -2877,6 +2877,32 @@ def OneFlow_FusedCrossFeatureInteractionV2GradOp : OneFlow_BaseOp<"fused_cross_f
let has_data_type_infer_fn = 1;
}

def OneFlow_ScaledDotProductFlashAttentionOp : OneFlow_BaseOp<"scaled_dot_product_flash_attention", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$query,
OneFlow_Tensor:$key,
OneFlow_Tensor:$value,
Optional<OneFlow_Tensor>:$alibi_slopes_
);
let output = (outs
OneFlow_Tensor:$out,
OneFlow_Tensor:$softmax_lse,
OneFlow_Tensor:$rng_state
);
let attrs = (ins
DefaultValuedAttr<F32Attr, "0.">:$p_dropout,
DefaultValuedAttr<F32Attr, "0.">:$softmax_scale,
DefaultValuedAttr<BoolAttr, "false">:$is_causal,
SI32Attr:$window_size_left,
SI32Attr:$window_size_right,
DefaultValuedAttr<SI64Attr, "0">:$seed
);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}
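
The two auxiliary outputs are bookkeeping for a future backward pass rather than part of the attention result: rng_state records the dropout RNG state so the mask can be replayed, and softmax_lse is expected to hold the per-row log-sum-exp of the scaled QK^T logits (shape (batch, num_heads, seqlen_q) by the flash-attention convention), which lets the softmax be recomputed without materializing the full attention matrix. A small single-head sketch of that quantity, as an illustration rather than the kernel's code:

import numpy as np

def row_logsumexp(q, k, scale):
    # q: (seqlen_q, head_dim), k: (seqlen_k, head_dim)
    logits = (q @ k.T) * scale                 # (seqlen_q, seqlen_k)
    m = logits.max(axis=-1, keepdims=True)     # stabilize before exponentiating
    lse = m + np.log(np.exp(logits - m).sum(axis=-1, keepdims=True))
    return lse[:, 0]                           # (seqlen_q,)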

def OneFlow_FusedMultiHeadAttentionInferenceOp : OneFlow_BaseOp<"fused_multi_head_attention_inference", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$query,