TorchFX: Initial scaled_dot_product_flash_attention
cavusmustafa committed Oct 12, 2023
1 parent e050d5f commit 500b2c0
Showing 3 changed files with 32 additions and 7 deletions.
@@ -91,6 +91,7 @@ def __init__(self):
"torch.ops.aten.relu.default": None,
"torch.ops.aten.relu_.default": None,
"torch.ops.aten.rsub.Scalar": None,
"torch.ops.aten._scaled_dot_product_flash_attention.default": None,
"torch.ops.aten.select.int": None,
"torch.ops.aten.sigmoid.default": None,
"torch.ops.aten.silu.default": None,
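For context, this table acts as the TorchDynamo operator-support registry: an ATen op whose qualified name appears as a key can be claimed for the OpenVINO subgraph during FX graph partitioning instead of falling back to eager execution. A minimal, illustrative sketch of that kind of lookup (the dictionary and helper below are simplified stand-ins, not the actual OpenVINO classes):

```python
# Illustrative sketch of a support-table lookup during FX partitioning.
# The dictionary mirrors the entries above; is_supported() is a
# simplified stand-in for the real partitioner logic, not OpenVINO API.
import torch

SUPPORTED_OPS = {
    "torch.ops.aten._scaled_dot_product_flash_attention.default": None,
    "torch.ops.aten.relu.default": None,
}

def is_supported(target) -> bool:
    # An OpOverload stringifies as e.g. "aten.relu.default"; the table
    # keys carry a "torch.ops." prefix, so prepend it before the lookup.
    return f"torch.ops.{target}" in SUPPORTED_OPS

print(is_supported(torch.ops.aten._scaled_dot_product_flash_attention.default))  # True
print(is_supported(torch.ops.aten.cos.default))                                  # False
```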
36 changes: 29 additions & 7 deletions src/frontends/pytorch/src/op/scaled_dot_product_attention.cpp
@@ -3,6 +3,7 @@
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/util/framework_node.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
@@ -15,6 +16,7 @@
#include "openvino/op/matmul.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/softmax.hpp"
@@ -31,10 +33,7 @@ namespace op {

using namespace ov::op;

OutputVector translate_scaled_dot_product_attention(const NodeContext& context) {
// aten::scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float
// dropout_p=0., bool is_causal=False)
num_inputs_check(context, 6, 6);
std::shared_ptr<ov::Node> translate_scaled_dot_product_attention_common(const NodeContext& context) {
auto query = context.get_input(0);
auto key = context.get_input(1);
auto value = context.get_input(2);
@@ -68,7 +67,10 @@ OutputVector translate_scaled_dot_product_attention(const NodeContext& context)
minus_inf = context.mark_node(std::make_shared<v1::ConvertLike>(minus_inf, scaled_atten));
// two types of masks are supported. A boolean mask where a value of True indicates that the element should take
// part in attention. A float mask of the same type as query, key, value that is added to the attention score.
auto is_causal = context.const_input<bool>(5);
auto is_causal = false;
if (!context.input_is_none(5)) {
is_causal = context.const_input<bool>(5);
}
if (is_causal || !context.input_is_none(3)) {
Output<Node> mask;
Output<Node> atten_mask;
@@ -100,10 +102,30 @@ OutputVector translate_scaled_dot_product_attention(const NodeContext& context)
scaled_atten = context.mark_node(std::make_shared<v1::Add>(scaled_atten, atten_mask));
}
scaled_atten = context.mark_node(std::make_shared<v8::Softmax>(scaled_atten, -1));
return {context.mark_node(std::make_shared<v0::MatMul>(scaled_atten, value))};
return context.mark_node(std::make_shared<v0::MatMul>(scaled_atten, value));
};

OutputVector translate_scaled_dot_product_attention(const NodeContext& context) {
// aten::scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float
// dropout_p=0., bool is_causal=False)
num_inputs_check(context, 6, 6);
return {translate_scaled_dot_product_attention_common(context)};
};

OutputVector translate_scaled_dot_product_attention_fx(const NodeContext& context) {
// aten::scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float
// dropout_p=0., bool is_causal=False)
num_inputs_check(context, 3, 6);
auto output = translate_scaled_dot_product_attention_common(context);
// TODO: scaled_dot_product_flash_attention has 9 outputs, but in most cases only
// the first output is used. The rest of the outputs should be returned properly
// as needed.
ov::OutputVector out_vec;
out_vec.push_back(output);
return {context.mark_node(make_list_construct(out_vec))};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
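For readers of the C++ above: the graph it assembles computes standard scaled dot-product attention, softmax(QKᵀ · scale + mask) · V, with the two mask conventions described in the comment and with is_causal defaulting to false when input 5 is absent. A hedged eager-mode reference in Python (a sketch of the same math, not the converter itself):

```python
# Hedged Python reference for the pattern the converter emits:
# softmax(Q @ K^T * scale + mask) @ V. Names and the exact scale/mask
# plumbing are simplified relative to the C++ graph; dropout is omitted.
import math
import torch

def sdpa_reference(query, key, value, attn_mask=None, is_causal=False):
    scale = 1.0 / math.sqrt(query.shape[-1])
    scores = (query @ key.transpose(-2, -1)) * scale
    if is_causal:
        # Causal mask: position i may only attend to positions <= i.
        L, S = query.shape[-2], key.shape[-2]
        causal = torch.ones(L, S, dtype=torch.bool).tril()
        scores = scores.masked_fill(~causal, float("-inf"))
    elif attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Boolean mask: True means the element takes part in attention.
            scores = scores.masked_fill(~attn_mask, float("-inf"))
        else:
            # Float mask: added directly to the attention scores.
            scores = scores + attn_mask
    return torch.softmax(scores, dim=-1) @ value
```

Up to numerics, this agrees with `torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask, is_causal=is_causal)` at `dropout_p=0`.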
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
@@ -212,6 +212,7 @@ OP_CONVERTER(translate_group_norm_fx);
OP_CONVERTER(translate_index_fx);
OP_CONVERTER(translate_layer_norm_fx);
OP_CONVERTER(translate_max_poolnd_fx);
OP_CONVERTER(translate_scaled_dot_product_attention_fx);
OP_CONVERTER(translate_slice_fx);
OP_CONVERTER(translate_softmax_fx);
OP_CONVERTER(translate_transpose_fx);
@@ -602,6 +603,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.relu.default", op::translate_1to1_match_1_inputs<opset10::Relu>},
{"aten.relu_.default", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Relu>>},
{"aten.rsub.Scalar", op::translate_rsub},
{"aten._scaled_dot_product_flash_attention.default", op::translate_scaled_dot_product_attention_fx},
{"aten.select.int", op::translate_select},
{"aten.sigmoid.default", op::translate_1to1_match_1_inputs<opset10::Sigmoid>},
{"aten.silu.default", op::translate_1to1_match_1_inputs<opset10::Swish>},

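Taken together, the three changes let a torch.compile'd model route flash-attention nodes through this converter. A hedged end-to-end sketch (it assumes an OpenVINO build where importing `openvino.torch` registers the "openvino" backend; whether tracing yields the `_scaled_dot_product_flash_attention` overload rather than another SDPA variant depends on device, dtype, and PyTorch version):

```python
# Hedged end-to-end sketch; "openvino" backend registration via
# `import openvino.torch` is an assumption about the installed build.
import torch
import torch.nn.functional as F
import openvino.torch  # noqa: F401

class Attention(torch.nn.Module):
    def forward(self, q, k, v):
        # May lower to aten._scaled_dot_product_flash_attention.default
        # in the captured FX graph, depending on device/dtype/version.
        return F.scaled_dot_product_attention(q, k, v)

model = torch.compile(Attention(), backend="openvino")
q = k = v = torch.randn(1, 8, 128, 64)   # (batch, heads, seq, head_dim)
print(model(q, k, v).shape)              # torch.Size([1, 8, 128, 64])
```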