From 2ca7f0f160ac0493f316f2a51d8f487eb95ac3d4 Mon Sep 17 00:00:00 2001
From: Cavus Mustafa <mustafa.cavus@intel.com>
Date: Tue, 10 Oct 2023 15:33:03 -0700
Subject: [PATCH] TorchFX: flash_attention support with SoftMax optimization (GPU)

---
 .../pytorch/torchdynamo/op_support.py         |   1 +
 src/frontends/pytorch/src/frontend.cpp        |   4 +-
 .../src/op/scaled_dot_product_attention.cpp   | 101 +++++++++++++++++-
 src/frontends/pytorch/src/op_table.cpp        |   2 +
 4 files changed, 106 insertions(+), 2 deletions(-)

diff --git a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py
index 726f3b598bc15e..46852106a750fd 100644
--- a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py
+++ b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py
@@ -89,6 +89,7 @@ def __init__(self):
             "torch.ops.aten.relu.default": None,
             "torch.ops.aten.relu_.default": None,
             "torch.ops.aten.rsub.Scalar": None,
+            "torch.ops.aten._scaled_dot_product_flash_attention.default": None,
             "torch.ops.aten.select.int": None,
             "torch.ops.aten.sigmoid.default": None,
             "torch.ops.aten.silu.default": None,
diff --git a/src/frontends/pytorch/src/frontend.cpp b/src/frontends/pytorch/src/frontend.cpp
index 0910aa3e057e72..f82c9c1faa5dac 100644
--- a/src/frontends/pytorch/src/frontend.cpp
+++ b/src/frontends/pytorch/src/frontend.cpp
@@ -199,7 +199,9 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
     manager.register_pass<ov::frontend::pytorch::pass::DictResolver>();
     manager.register_pass<ov::frontend::pytorch::pass::IndexLoopGetitemReplacer>();
     manager.register_pass<ov::frontend::pytorch::pass::QuantizedNodeRemover>();
-    manager.register_pass<ov::frontend::pytorch::pass::SoftmaxReshapeElimination>();
+    // NOTE(review): SoftmaxReshapeElimination is switched off for every model; presumably it would remove the
+    // Reshape -> Softmax -> Reshape chain created by translate_scaled_dot_product_attention_fx - confirm.
+    // manager.register_pass<ov::frontend::pytorch::pass::SoftmaxReshapeElimination>();
     manager.register_pass<ov::pass::RemoveMultiSubGraphOpDanglingParamsResults>();
     manager.register_pass<ov::pass::ReverseShapeAndTypeInfer>();
     manager.register_pass<ov::pass::ResolveNameCollisions>();
diff --git a/src/frontends/pytorch/src/op/scaled_dot_product_attention.cpp b/src/frontends/pytorch/src/op/scaled_dot_product_attention.cpp
index 735324405d1f11..a73147c78ff706 100644
--- a/src/frontends/pytorch/src/op/scaled_dot_product_attention.cpp
+++ b/src/frontends/pytorch/src/op/scaled_dot_product_attention.cpp
@@ -3,6 +3,7 @@
 //
 
 #include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/op/util/framework_node.hpp"
 #include "openvino/op/add.hpp"
 #include "openvino/op/broadcast.hpp"
 #include "openvino/op/concat.hpp"
@@ -15,6 +16,7 @@
 #include "openvino/op/matmul.hpp"
 #include "openvino/op/multiply.hpp"
 #include "openvino/op/range.hpp"
+#include "openvino/op/reshape.hpp"
 #include "openvino/op/select.hpp"
 #include "openvino/op/shape_of.hpp"
 #include "openvino/op/softmax.hpp"
@@ -103,7 +105,104 @@ OutputVector translate_scaled_dot_product_attention(const NodeContext& context)
     return {context.mark_node(std::make_shared<v0::MatMul>(scaled_atten, value))};
 };
 
+OutputVector translate_scaled_dot_product_attention_fx(const NodeContext& context) {
+    // Translates aten._scaled_dot_product_flash_attention.default (TorchFX). The inputs are interpreted as
+    // aten::scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float
+    // dropout_p=0., bool is_causal=False).
+    // NOTE(review): the flash-attention schema presumably orders its arguments differently (dropout_p and
+    // is_causal directly after value) - confirm what inputs 3..5 hold before relying on them.
+    num_inputs_check(context, 3, 6);
+    auto query = context.get_input(0);
+    auto key = context.get_input(1);
+    auto value = context.get_input(2);
+    auto q_shape = context.mark_node(std::make_shared<v3::ShapeOf>(query, element::i32));
+    auto k_shape = context.mark_node(std::make_shared<v3::ShapeOf>(key, element::i32));
+    auto minus_one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
+    auto minus_two = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-2}));
+    auto zero_i = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
+    auto one_i = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
+
+    // scale = 1 / sqrt(last dimension of query), computed in the element type of query.
+    auto scale = context.mark_node(std::make_shared<v8::Gather>(q_shape, minus_one, zero_i));
+    scale = context.mark_node(std::make_shared<v1::ConvertLike>(scale, query));
+    auto sqrt_scale = context.mark_node(std::make_shared<v0::Sqrt>(scale));
+    auto one_f = context.mark_node(std::make_shared<v1::ConvertLike>(one_i, sqrt_scale));
+    auto zero_f = context.mark_node(std::make_shared<v1::ConvertLike>(zero_i, sqrt_scale));
+    scale = context.mark_node(std::make_shared<v1::Divide>(one_f, sqrt_scale));
+    auto q_scaled = context.mark_node(std::make_shared<v1::Multiply>(query, scale));
+
+    // Transpose key with the permutation [0, ..., rank-3, rank-1, rank-2], i.e. swap its two innermost dimensions.
+    auto k_rank = context.mark_node(std::make_shared<v3::ShapeOf>(k_shape, element::i32));
+    auto k_last_dim = context.mark_node(std::make_shared<v1::Add>(k_rank, minus_one));
+    auto k_next_dim = context.mark_node(std::make_shared<v1::Add>(k_rank, minus_two));
+    auto minus_inf =
+        context.mark_node(v0::Constant::create(element::f32, Shape{}, {-std::numeric_limits<float>::infinity()}));
+    auto keep_dim_last = context.mark_node(std::make_shared<v0::Squeeze>(k_next_dim, zero_i));
+    auto k_dims_before_transpose =
+        context.mark_node(std::make_shared<v4::Range>(zero_i, keep_dim_last, one_i, element::i32));
+    auto transpose_dims = context.mark_node(
+        std::make_shared<v0::Concat>(OutputVector{k_dims_before_transpose, k_last_dim, k_next_dim}, 0));
+    auto k_transposed = context.mark_node(std::make_shared<v1::Transpose>(key, transpose_dims));
+    auto scaled_atten = context.mark_node(std::make_shared<v0::MatMul>(q_scaled, k_transposed));
+    minus_inf = context.mark_node(std::make_shared<v1::ConvertLike>(minus_inf, scaled_atten));
+
+    // Two types of masks are supported: a boolean mask where True marks an element that takes part in attention,
+    // and a float mask of the same type as query, key, value that is added to the attention score.
+    // TODO: is_causal is not read from the inputs yet (context.const_input<bool>(5) is disabled), so the causal
+    // branch below is never taken for now.
+    const bool is_causal = false;
+    if (is_causal || !context.input_is_none(3)) {
+        Output<Node> atten_mask;
+        if (!context.input_is_none(3)) {
+            auto mask = context.get_input(3);
+            if (mask.get_element_type() == element::boolean) {
+                // True keeps the score unchanged (adds 0); False excludes the element (adds -inf).
+                atten_mask = context.mark_node(std::make_shared<v1::Select>(mask, zero_f, minus_inf));
+            } else {
+                atten_mask = mask;
+            }
+        } else {
+            // Causal mask: -inf strictly above the main diagonal of a [target_len, source_len] matrix, 0 elsewhere.
+            auto target_s_len = context.mark_node(std::make_shared<v8::Gather>(q_shape, minus_two, zero_i));
+            auto source_s_len = context.mark_node(std::make_shared<v8::Gather>(k_shape, minus_two, zero_i));
+            auto ssl = context.mark_node(std::make_shared<v0::Unsqueeze>(source_s_len, zero_i));
+            auto tsl = context.mark_node(std::make_shared<v0::Unsqueeze>(target_s_len, zero_i));
+            auto mask_shape = context.mark_node(std::make_shared<v0::Concat>(OutputVector{tsl, ssl}, 0));
+            auto inf_mask = context.mark_node(std::make_shared<v1::Broadcast>(minus_inf, mask_shape));
+            auto horizontal_range =
+                context.mark_node(std::make_shared<v4::Range>(zero_i, source_s_len, one_i, element::i32));
+            horizontal_range = context.mark_node(std::make_shared<v0::Unsqueeze>(horizontal_range, zero_i));
+            auto stop = context.mark_node(std::make_shared<v1::Add>(target_s_len, one_i));
+            auto vertical_range = context.mark_node(std::make_shared<v4::Range>(one_i, stop, one_i, element::i32));
+            vertical_range = context.mark_node(std::make_shared<v0::Unsqueeze>(vertical_range, one_i));
+            auto triu = context.mark_node(std::make_shared<v1::GreaterEqual>(horizontal_range, vertical_range));
+            atten_mask = context.mark_node(std::make_shared<v1::Select>(triu, inf_mask, zero_f));
+        }
+        scaled_atten = context.mark_node(std::make_shared<v1::Add>(scaled_atten, atten_mask));
+    }
+
+    // Softmax runs on a 3D view: every dimension except the last two is collapsed into one and restored afterwards.
+    // The shapes are taken from the graph (ShapeOf) rather than hard-coded, so any leading dimensions and dynamic
+    // shapes are supported.
+    auto atten_shape = context.mark_node(std::make_shared<v3::ShapeOf>(scaled_atten, element::i32));
+    auto last_two_axes = context.mark_node(v0::Constant::create(element::i32, Shape{2}, {-2, -1}));
+    auto seq_dims = context.mark_node(std::make_shared<v8::Gather>(atten_shape, last_two_axes, zero_i));
+    auto flatten_dim = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
+    auto softmax_shape = context.mark_node(std::make_shared<v0::Concat>(OutputVector{flatten_dim, seq_dims}, 0));
+    auto reshape_3d = context.mark_node(std::make_shared<v1::Reshape>(scaled_atten, softmax_shape, false));
+    auto softmax = context.mark_node(std::make_shared<v8::Softmax>(reshape_3d, -1));
+    auto restored = context.mark_node(std::make_shared<v1::Reshape>(softmax, atten_shape, false));
+    auto output = context.mark_node(std::make_shared<v0::MatMul>(restored, value));
+
+    // The result is returned as a list that holds the attention output num_outputs times.
+    // NOTE(review): presumably only the first entry is consumed and the remaining entries stand in for auxiliary
+    // outputs of the flash-attention op that are not computed here - confirm the expected number of outputs.
+    constexpr size_t num_outputs = 8;
+    ov::OutputVector out_vec(num_outputs, Output<Node>(output));
+    return {context.mark_node(make_list_construct(out_vec))};
+};
+
 }  // namespace op
 }  // namespace pytorch
 }  // namespace frontend
-}  // namespace ov
\ No newline at end of file
+}  // namespace ov
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index 20c53dbe52bc9f..5a6ca488da2834 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -207,6 +207,7 @@ OP_CONVERTER(translate_group_norm_fx);
 OP_CONVERTER(translate_index_fx);
 OP_CONVERTER(translate_layer_norm_fx);
 OP_CONVERTER(translate_max_poolnd_fx);
+OP_CONVERTER(translate_scaled_dot_product_attention_fx);
 OP_CONVERTER(translate_slice_fx);
 OP_CONVERTER(translate_softmax_fx);
 OP_CONVERTER(translate_transpose_fx);
@@ -590,6 +591,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
         {"aten.relu.default", op::translate_1to1_match_1_inputs<opset10::Relu>},
         {"aten.relu_.default", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Relu>>},
         {"aten.rsub.Scalar", op::translate_rsub},
+        {"aten._scaled_dot_product_flash_attention.default", op::translate_scaled_dot_product_attention_fx},
         {"aten.select.int", op::translate_select},
         {"aten.sigmoid.default", op::translate_1to1_match_1_inputs<opset10::Sigmoid>},
         {"aten.silu.default", op::translate_1to1_match_1_inputs<opset10::Swish>},