Skip to content

Commit

Permalink
TorchFX: flash_attention support with SoftMax optimization (GPU)
Browse files Browse the repository at this point in the history
  • Loading branch information
cavusmustafa committed Oct 10, 2023
1 parent 8fbe26c commit 2ca7f0f
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(self):
"torch.ops.aten.relu.default": None,
"torch.ops.aten.relu_.default": None,
"torch.ops.aten.rsub.Scalar": None,
"torch.ops.aten._scaled_dot_product_flash_attention.default": None,
"torch.ops.aten.select.int": None,
"torch.ops.aten.sigmoid.default": None,
"torch.ops.aten.silu.default": None,
Expand Down
2 changes: 1 addition & 1 deletion src/frontends/pytorch/src/frontend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
manager.register_pass<ov::frontend::pytorch::pass::DictResolver>();
manager.register_pass<ov::frontend::pytorch::pass::IndexLoopGetitemReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::QuantizedNodeRemover>();
manager.register_pass<ov::frontend::pytorch::pass::SoftmaxReshapeElimination>();
//manager.register_pass<ov::frontend::pytorch::pass::SoftmaxReshapeElimination>();
manager.register_pass<ov::pass::RemoveMultiSubGraphOpDanglingParamsResults>();
manager.register_pass<ov::pass::ReverseShapeAndTypeInfer>();
manager.register_pass<ov::pass::ResolveNameCollisions>();
Expand Down
98 changes: 97 additions & 1 deletion src/frontends/pytorch/src/op/scaled_dot_product_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/util/framework_node.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
Expand All @@ -15,6 +16,7 @@
#include "openvino/op/matmul.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/softmax.hpp"
Expand Down Expand Up @@ -103,7 +105,101 @@ OutputVector translate_scaled_dot_product_attention(const NodeContext& context)
return {context.mark_node(std::make_shared<v0::MatMul>(scaled_atten, value))};
};

OutputVector translate_scaled_dot_product_attention_fx(const NodeContext& context) {
// aten::scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float
// dropout_p=0., bool is_causal=False)
num_inputs_check(context, 3, 6);
auto query = context.get_input(0);
auto key = context.get_input(1);
auto value = context.get_input(2);
auto q_shape = context.mark_node(std::make_shared<v3::ShapeOf>(query, element::i32));
auto k_shape = context.mark_node(std::make_shared<v3::ShapeOf>(key, element::i32));
auto minus_one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
auto minus_two = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-2}));
auto zero_i = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
auto one_i = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
auto scale = context.mark_node(std::make_shared<v8::Gather>(q_shape, minus_one, zero_i));
scale = context.mark_node(std::make_shared<v1::ConvertLike>(scale, query));
auto sqrt_scale = context.mark_node(std::make_shared<v0::Sqrt>(scale));
auto one_f = context.mark_node(std::make_shared<v1::ConvertLike>(one_i, sqrt_scale));
auto zero_f = context.mark_node(std::make_shared<v1::ConvertLike>(zero_i, sqrt_scale));
scale = context.mark_node(std::make_shared<v1::Divide>(one_f, sqrt_scale));
auto q_scaled = context.mark_node(std::make_shared<v1::Multiply>(query, scale));
auto k_rank = context.mark_node(std::make_shared<v3::ShapeOf>(k_shape, element::i32));
auto k_last_dim = context.mark_node(std::make_shared<v1::Add>(k_rank, minus_one));
auto k_next_dim = context.mark_node(std::make_shared<v1::Add>(k_rank, minus_two));
k_rank = context.mark_node(std::make_shared<v0::Squeeze>(k_rank, zero_i));
auto minus_inf =
context.mark_node(v0::Constant::create(element::f32, Shape{}, {-std::numeric_limits<float>::infinity()}));
auto keep_dim_last = context.mark_node(std::make_shared<v0::Squeeze>(k_next_dim, zero_i));
auto k_dims_before_transpose =
context.mark_node(std::make_shared<v4::Range>(zero_i, keep_dim_last, one_i, element::i32));

auto transpose_dims = context.mark_node(
std::make_shared<v0::Concat>(OutputVector{k_dims_before_transpose, k_last_dim, k_next_dim}, 0));
auto k_transposed = context.mark_node(std::make_shared<v1::Transpose>(key, transpose_dims));
auto scaled_atten = context.mark_node(std::make_shared<v0::MatMul>(q_scaled, k_transposed));
minus_inf = context.mark_node(std::make_shared<v1::ConvertLike>(minus_inf, scaled_atten));
// two types of masks are supported. A boolean mask where a value of True indicates that the element should take
// part in attention. A float mask of the same type as query, key, value that is added to the attention score.
//auto is_causal = context.const_input<bool>(5);
auto is_causal = false;
if (is_causal || !context.input_is_none(3)) {
Output<Node> mask;
Output<Node> atten_mask;
if (!context.input_is_none(3)) {
mask = context.get_input(3);
if (mask.get_element_type() == element::boolean) {
atten_mask = context.mark_node(std::make_shared<v1::ConvertLike>(mask, scaled_atten));
auto inv_mask = context.mark_node(std::make_shared<v1::LogicalNot>(mask));
atten_mask = context.mark_node(std::make_shared<v1::Select>(inv_mask, atten_mask, minus_inf));
} else {
atten_mask = mask;
}
} else {
auto target_s_len = context.mark_node(std::make_shared<v8::Gather>(q_shape, minus_two, zero_i));
auto source_s_len = context.mark_node(std::make_shared<v8::Gather>(k_shape, minus_two, zero_i));
auto ssl = context.mark_node(std::make_shared<v0::Unsqueeze>(source_s_len, zero_i));
auto tsl = context.mark_node(std::make_shared<v0::Unsqueeze>(target_s_len, zero_i));
auto mask_shape = context.mark_node(std::make_shared<v0::Concat>(OutputVector{tsl, ssl}, 0));
mask = context.mark_node(std::make_shared<v1::Broadcast>(minus_inf, mask_shape));
auto horizontal_range =
context.mark_node(std::make_shared<v4::Range>(zero_i, source_s_len, one_i, element::i32));
horizontal_range = context.mark_node(std::make_shared<v0::Unsqueeze>(horizontal_range, zero_i));
auto stop = context.mark_node(std::make_shared<v1::Add>(target_s_len, one_i));
auto vertical_range = context.mark_node(std::make_shared<v4::Range>(one_i, stop, one_i, element::i32));
vertical_range = context.mark_node(std::make_shared<v0::Unsqueeze>(vertical_range, one_i));
auto triu = context.mark_node(std::make_shared<v1::GreaterEqual>(horizontal_range, vertical_range));
atten_mask = context.mark_node(std::make_shared<v1::Select>(triu, mask, zero_f));
}
scaled_atten = context.mark_node(std::make_shared<v1::Add>(scaled_atten, atten_mask));
}
std::vector<int> softmax_shape_v(3);
softmax_shape_v[0] = 16;
softmax_shape_v[1] = scaled_atten->get_shape()[2];
softmax_shape_v[2] = scaled_atten->get_shape()[3];
std::vector<int> matmul_shape_v(4);
matmul_shape_v[0] = 2;
matmul_shape_v[1] = 8;
matmul_shape_v[2] = scaled_atten->get_shape()[2];
matmul_shape_v[3] = scaled_atten->get_shape()[3];
auto softmax_shape = context.mark_node(v0::Constant::create(element::i32, Shape{3}, softmax_shape_v));
auto reshape_3d = context.mark_node(std::make_shared<v1::Reshape>(scaled_atten, softmax_shape, false));
scaled_atten = context.mark_node(std::make_shared<v8::Softmax>(reshape_3d, -1));
auto matmul_shape = context.mark_node(v0::Constant::create(element::i32, Shape{4}, matmul_shape_v));
auto reshape_4d = context.mark_node(std::make_shared<v1::Reshape>(scaled_atten, matmul_shape, true));
auto output = context.mark_node(std::make_shared<v0::MatMul>(reshape_4d, value));

//scaled_atten = context.mark_node(std::make_shared<v8::Softmax>(scaled_atten, -1));
//auto output = context.mark_node(std::make_shared<v0::MatMul>(scaled_atten, value));
ov::OutputVector out_vec;
for (int i=0; i<8; i++) {
out_vec.push_back(output);
}
return {context.mark_node(make_list_construct(out_vec))};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
} // namespace ov
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ OP_CONVERTER(translate_group_norm_fx);
OP_CONVERTER(translate_index_fx);
OP_CONVERTER(translate_layer_norm_fx);
OP_CONVERTER(translate_max_poolnd_fx);
OP_CONVERTER(translate_scaled_dot_product_attention_fx);
OP_CONVERTER(translate_slice_fx);
OP_CONVERTER(translate_softmax_fx);
OP_CONVERTER(translate_transpose_fx);
Expand Down Expand Up @@ -590,6 +591,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.relu.default", op::translate_1to1_match_1_inputs<opset10::Relu>},
{"aten.relu_.default", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Relu>>},
{"aten.rsub.Scalar", op::translate_rsub},
{"aten._scaled_dot_product_flash_attention.default", op::translate_scaled_dot_product_attention_fx},
{"aten.select.int", op::translate_select},
{"aten.sigmoid.default", op::translate_1to1_match_1_inputs<opset10::Sigmoid>},
{"aten.silu.default", op::translate_1to1_match_1_inputs<opset10::Swish>},
Expand Down

0 comments on commit 2ca7f0f

Please sign in to comment.