From 2ca7f0f160ac0493f316f2a51d8f487eb95ac3d4 Mon Sep 17 00:00:00 2001
From: Cavus Mustafa <mustafa.cavus@intel.com>
Date: Tue, 10 Oct 2023 15:33:03 -0700
Subject: [PATCH] TorchFX: flash_attention support with SoftMax optimization
 (GPU)
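
torch.compile lowers scaled dot product attention to
aten._scaled_dot_product_flash_attention.default in FX graphs. Mark the
op as supported for torchdynamo, add a TorchFX translator for it, and
reshape the attention scores to 3D around the SoftMax (restoring 4D
before the final MatMul with value) so that an optimized GPU SoftMax
kernel can be used. The SoftmaxReshapeElimination pass is disabled so
that normalization does not fold this pattern away.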

---
 .../pytorch/torchdynamo/op_support.py         |   1 +
 src/frontends/pytorch/src/frontend.cpp        |   4 +++-
 .../src/op/scaled_dot_product_attention.cpp   | 108 ++++++++++++++++++-
 src/frontends/pytorch/src/op_table.cpp        |   2 +
 4 files changed, 113 insertions(+), 2 deletions(-)

diff --git a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py
index 726f3b598bc15e..46852106a750fd 100644
--- a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py
+++ b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py
@@ -89,6 +89,7 @@ def __init__(self):
             "torch.ops.aten.relu.default": None,
             "torch.ops.aten.relu_.default": None,
             "torch.ops.aten.rsub.Scalar": None,
+            "torch.ops.aten._scaled_dot_product_flash_attention.default": None,
             "torch.ops.aten.select.int": None,
             "torch.ops.aten.sigmoid.default": None,
             "torch.ops.aten.silu.default": None,
diff --git a/src/frontends/pytorch/src/frontend.cpp b/src/frontends/pytorch/src/frontend.cpp
index 0910aa3e057e72..f82c9c1faa5dac 100644
--- a/src/frontends/pytorch/src/frontend.cpp
+++ b/src/frontends/pytorch/src/frontend.cpp
@@ -199,7 +199,9 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
     manager.register_pass<ov::frontend::pytorch::pass::DictResolver>();
     manager.register_pass<ov::frontend::pytorch::pass::IndexLoopGetitemReplacer>();
     manager.register_pass<ov::frontend::pytorch::pass::QuantizedNodeRemover>();
-    manager.register_pass<ov::frontend::pytorch::pass::SoftmaxReshapeElimination>();
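+    // Disabled: SoftmaxReshapeElimination would fold away the Reshape->SoftMax->Reshape pattern
+    // introduced by the TorchFX flash attention translator for the GPU SoftMax optimization.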
+    //manager.register_pass<ov::frontend::pytorch::pass::SoftmaxReshapeElimination>();
     manager.register_pass<ov::pass::RemoveMultiSubGraphOpDanglingParamsResults>();
     manager.register_pass<ov::pass::ReverseShapeAndTypeInfer>();
     manager.register_pass<ov::pass::ResolveNameCollisions>();
diff --git a/src/frontends/pytorch/src/op/scaled_dot_product_attention.cpp b/src/frontends/pytorch/src/op/scaled_dot_product_attention.cpp
index 735324405d1f11..a73147c78ff706 100644
--- a/src/frontends/pytorch/src/op/scaled_dot_product_attention.cpp
+++ b/src/frontends/pytorch/src/op/scaled_dot_product_attention.cpp
@@ -3,6 +3,7 @@
 //
 
 #include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/op/util/framework_node.hpp"
 #include "openvino/op/add.hpp"
 #include "openvino/op/broadcast.hpp"
 #include "openvino/op/concat.hpp"
@@ -15,6 +16,7 @@
 #include "openvino/op/matmul.hpp"
 #include "openvino/op/multiply.hpp"
 #include "openvino/op/range.hpp"
+#include "openvino/op/reshape.hpp"
 #include "openvino/op/select.hpp"
 #include "openvino/op/shape_of.hpp"
 #include "openvino/op/softmax.hpp"
@@ -103,7 +105,111 @@ OutputVector translate_scaled_dot_product_attention(const NodeContext& context)
     return {context.mark_node(std::make_shared<v0::MatMul>(scaled_atten, value))};
 };
 
+OutputVector translate_scaled_dot_product_attention_fx(const NodeContext& context) {
+    // aten._scaled_dot_product_flash_attention.default(Tensor query, Tensor key, Tensor value,
+    // float dropout_p=0., bool is_causal=False, ...)
+    // The body below follows the generic aten::scaled_dot_product_attention lowering and treats
+    // input 3 as an optional attention mask when one is present in the graph.
+    num_inputs_check(context, 3, 6);
+    auto query = context.get_input(0);
+    auto key = context.get_input(1);
+    auto value = context.get_input(2);
+    auto q_shape = context.mark_node(std::make_shared<v3::ShapeOf>(query, element::i32));
+    auto k_shape = context.mark_node(std::make_shared<v3::ShapeOf>(key, element::i32));
+    auto minus_one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
+    auto minus_two = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-2}));
+    auto zero_i = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
+    auto one_i = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
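+    // Default scale factor: 1 / sqrt(E), where E is the last (embedding) dimension of query.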
+    auto scale = context.mark_node(std::make_shared<v8::Gather>(q_shape, minus_one, zero_i));
+    scale = context.mark_node(std::make_shared<v1::ConvertLike>(scale, query));
+    auto sqrt_scale = context.mark_node(std::make_shared<v0::Sqrt>(scale));
+    auto one_f = context.mark_node(std::make_shared<v1::ConvertLike>(one_i, sqrt_scale));
+    auto zero_f = context.mark_node(std::make_shared<v1::ConvertLike>(zero_i, sqrt_scale));
+    scale = context.mark_node(std::make_shared<v1::Divide>(one_f, sqrt_scale));
+    auto q_scaled = context.mark_node(std::make_shared<v1::Multiply>(query, scale));
+    auto k_rank = context.mark_node(std::make_shared<v3::ShapeOf>(k_shape, element::i32));
+    auto k_last_dim = context.mark_node(std::make_shared<v1::Add>(k_rank, minus_one));
+    auto k_next_dim = context.mark_node(std::make_shared<v1::Add>(k_rank, minus_two));
+    k_rank = context.mark_node(std::make_shared<v0::Squeeze>(k_rank, zero_i));
+    auto minus_inf =
+        context.mark_node(v0::Constant::create(element::f32, Shape{}, {-std::numeric_limits<float>::infinity()}));
+    auto keep_dim_last = context.mark_node(std::make_shared<v0::Squeeze>(k_next_dim, zero_i));
+    auto k_dims_before_transpose =
+        context.mark_node(std::make_shared<v4::Range>(zero_i, keep_dim_last, one_i, element::i32));
+
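+    // Permutation that swaps the last two dims of key, producing K^T for the Q*K^T MatMul.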
+    auto transpose_dims = context.mark_node(
+        std::make_shared<v0::Concat>(OutputVector{k_dims_before_transpose, k_last_dim, k_next_dim}, 0));
+    auto k_transposed = context.mark_node(std::make_shared<v1::Transpose>(key, transpose_dims));
+    auto scaled_atten = context.mark_node(std::make_shared<v0::MatMul>(q_scaled, k_transposed));
+    minus_inf = context.mark_node(std::make_shared<v1::ConvertLike>(minus_inf, scaled_atten));
+    // two types of masks are supported. A boolean mask where a value of True indicates that the element should take
+    // part in attention. A float mask of the same type as query, key, value that is added to the attention score.
+    // is_causal is hardcoded to false for now instead of being read from input 5:
+    // auto is_causal = context.const_input<bool>(5);
+    auto is_causal = false;
+    if (is_causal || !context.input_is_none(3)) {
+        Output<Node> mask;
+        Output<Node> atten_mask;
+        if (!context.input_is_none(3)) {
+            mask = context.get_input(3);
+            if (mask.get_element_type() == element::boolean) {
+                atten_mask = context.mark_node(std::make_shared<v1::ConvertLike>(mask, scaled_atten));
+                auto inv_mask = context.mark_node(std::make_shared<v1::LogicalNot>(mask));
+                atten_mask = context.mark_node(std::make_shared<v1::Select>(inv_mask, atten_mask, minus_inf));
+            } else {
+                atten_mask = mask;
+            }
+        } else {
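+            // No explicit mask: build a causal mask with -inf in the upper triangle (source > target).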
+            auto target_s_len = context.mark_node(std::make_shared<v8::Gather>(q_shape, minus_two, zero_i));
+            auto source_s_len = context.mark_node(std::make_shared<v8::Gather>(k_shape, minus_two, zero_i));
+            auto ssl = context.mark_node(std::make_shared<v0::Unsqueeze>(source_s_len, zero_i));
+            auto tsl = context.mark_node(std::make_shared<v0::Unsqueeze>(target_s_len, zero_i));
+            auto mask_shape = context.mark_node(std::make_shared<v0::Concat>(OutputVector{tsl, ssl}, 0));
+            mask = context.mark_node(std::make_shared<v1::Broadcast>(minus_inf, mask_shape));
+            auto horizontal_range =
+                context.mark_node(std::make_shared<v4::Range>(zero_i, source_s_len, one_i, element::i32));
+            horizontal_range = context.mark_node(std::make_shared<v0::Unsqueeze>(horizontal_range, zero_i));
+            auto stop = context.mark_node(std::make_shared<v1::Add>(target_s_len, one_i));
+            auto vertical_range = context.mark_node(std::make_shared<v4::Range>(one_i, stop, one_i, element::i32));
+            vertical_range = context.mark_node(std::make_shared<v0::Unsqueeze>(vertical_range, one_i));
+            auto triu = context.mark_node(std::make_shared<v1::GreaterEqual>(horizontal_range, vertical_range));
+            atten_mask = context.mark_node(std::make_shared<v1::Select>(triu, mask, zero_f));
+        }
+        scaled_atten = context.mark_node(std::make_shared<v1::Add>(scaled_atten, atten_mask));
+    }
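+
+    // SoftMax optimization for GPU: run SoftMax on a 3D view of the attention scores so an
+    // optimized kernel can be selected, then restore the 4D shape for the final MatMul.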
+    const auto atten_shape = scaled_atten->get_shape();
+    // Flattened 3D SoftMax shape [batch * heads, L, S]: the batch (2) and head (8) counts are
+    // hardcoded in this prototype, so the attention scores must have a static shape.
+    std::vector<int> softmax_shape_v{16, static_cast<int>(atten_shape[2]), static_cast<int>(atten_shape[3])};
+    // Restored 4D shape [batch, heads, L, S] for the MatMul with value.
+    std::vector<int> matmul_shape_v{2, 8, static_cast<int>(atten_shape[2]), static_cast<int>(atten_shape[3])};
+    auto softmax_shape = context.mark_node(v0::Constant::create(element::i32, Shape{3}, softmax_shape_v));
+    auto reshape_3d = context.mark_node(std::make_shared<v1::Reshape>(scaled_atten, softmax_shape, false));
+    scaled_atten = context.mark_node(std::make_shared<v8::Softmax>(reshape_3d, -1));
+    auto matmul_shape = context.mark_node(v0::Constant::create(element::i32, Shape{4}, matmul_shape_v));
+    auto reshape_4d = context.mark_node(std::make_shared<v1::Reshape>(scaled_atten, matmul_shape, true));
+    auto output = context.mark_node(std::make_shared<v0::MatMul>(reshape_4d, value));
+
+    // Non-optimized path, kept for reference:
+    //scaled_atten = context.mark_node(std::make_shared<v8::Softmax>(scaled_atten, -1));
+    //auto output = context.mark_node(std::make_shared<v0::MatMul>(scaled_atten, value));
+
+    // _scaled_dot_product_flash_attention produces multiple outputs; only the first (the attention
+    // output) is computed here, so it is replicated to stand in for the remaining ones.
+    ov::OutputVector out_vec;
+    for (int i = 0; i < 8; i++) {
+        out_vec.push_back(output);
+    }
+    return {context.mark_node(make_list_construct(out_vec))};
+};
+
 }  // namespace op
 }  // namespace pytorch
 }  // namespace frontend
-}  // namespace ov
\ No newline at end of file
+}  // namespace ov
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index 20c53dbe52bc9f..5a6ca488da2834 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -207,6 +207,7 @@ OP_CONVERTER(translate_group_norm_fx);
 OP_CONVERTER(translate_index_fx);
 OP_CONVERTER(translate_layer_norm_fx);
 OP_CONVERTER(translate_max_poolnd_fx);
+OP_CONVERTER(translate_scaled_dot_product_attention_fx);
 OP_CONVERTER(translate_slice_fx);
 OP_CONVERTER(translate_softmax_fx);
 OP_CONVERTER(translate_transpose_fx);
@@ -590,6 +591,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
         {"aten.relu.default", op::translate_1to1_match_1_inputs<opset10::Relu>},
         {"aten.relu_.default", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Relu>>},
         {"aten.rsub.Scalar", op::translate_rsub},
+        {"aten._scaled_dot_product_flash_attention.default", op::translate_scaled_dot_product_attention_fx},
         {"aten.select.int", op::translate_select},
         {"aten.sigmoid.default", op::translate_1to1_match_1_inputs<opset10::Sigmoid>},
         {"aten.silu.default", op::translate_1to1_match_1_inputs<opset10::Swish>},