Skip to content

Commit

Permalink
LSTMCellFusion - support transposed/not transposed weights (openvinot…
Browse files Browse the repository at this point in the history
…oolkit#21780)

* LSTMCellFusion - support transposed/not transposed weights

* add comment describing fused subgraph
  • Loading branch information
mateusztabaka authored Dec 22, 2023
1 parent bc121c0 commit 2fcaa88
Show file tree
Hide file tree
Showing 2 changed files with 196 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,97 @@
#include "openvino/op/split.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/tanh.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/variadic_split.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"
#include "validation_util.hpp"

/*
The following graph is fused to LSTMCell
+-----+ +-----+
| X | | H |
+--+--+ +--+--+
| |
+---+ +---+
| |
v v
+--+--+--+ +------+
| Concat | | WR |
+----+---+ +---+--+
| |
| +--------+
| |
v v
+--+--+--+ +------+
| MatMul | | Bias |
+----+---+ +--+---+
| |
| +------+
| |
v v
+--+---+--+
| Add |
+----+----+
|
|
v
+------+-------+
| Split |
+--+--+--+--+--+
| | | |
+--------------+ | | +------------------------------+
| | | |
v | +------+ +-------+ v
+------+-----+ +-----+ | | const | +------+-----+
| Activation | | | +---+---+ | Activation |
| (i_t) | | | | | (o_t) |
+------+-----+ | | +---+ +------+-----+
| v | | |
| +------+-----+ v v |
| | Activation | +-+---+-+ |
| | (c_t) | | Add | |
| +------+-----+ +---+---+ |
| | | |
| | v |
+---+ +---+ +------+-----+ |
| | | Activation | +-----+ |
v v | (f_t) | | C | |
+--+---+---+ +------------+ +-----+ |
| Multiply | | | |
+----+-----+ | +--------+ |
| | | |
| v v |
| +---+---+--+ |
| | Multiply | |
| +----+-----+ |
| | |
| | |
+---------+ +--------+ |
| | |
v v |
+-+-----+-+ |
| Add | |
| (C out) | |
+----+----+ |
| |
v |
+-----+------+ |
| Activation | |
+-----+------+ |
| |
| |
+----------+ +-------------------+
| |
v v
+--+----+--+
| Multiply |
| (H out) |
+----------+
*/

static std::string get_activation_name(const std::shared_ptr<ov::Node>& node) {
std::string name = node->get_type_name();
name[0] = std::tolower(name[0]);
Expand All @@ -37,9 +124,7 @@ ov::pass::LSTMCellFusion::LSTMCellFusion() {
return pattern::has_static_shape()(output) && pattern::rank_equals(2)(output);
});
auto matmul_label = pattern::wrap_type<op::v0::MatMul>({concat_label, weights_label});
auto bias_label = pattern::any_input([](const Output<Node>& output) {
return pattern::has_static_shape()(output) && pattern::rank_equals(2)(output);
});
auto bias_label = pattern::any_input(pattern::has_static_shape());
auto bias_add_label = pattern::wrap_type<op::v1::Add>({matmul_label, bias_label});
auto axis_label = pattern::wrap_type<op::v0::Constant>();
auto split_label = pattern::wrap_type<op::v1::Split>({bias_add_label, axis_label});
Expand All @@ -62,51 +147,67 @@ ov::pass::LSTMCellFusion::LSTMCellFusion() {
const auto& X = pattern_map.at(x_label);
const auto& H = pattern_map.at(h_label);
const auto& C = pattern_map.at(c_label);
const auto& WR = pattern_map.at(weights_label);
const auto& B = pattern_map.at(bias_label);
auto WR = pattern_map.at(weights_label);
auto B = pattern_map.at(bias_label);
const auto& ft_additional_bias = pattern_map.at(ft_additional_bias_label);
auto Ho = pattern_map.at(Ho_label);
auto Co = pattern_map.at(Co_label);
const auto matmul = ov::as_type_ptr<op::v0::MatMul>(pattern_map.at(matmul_label).get_node_shared_ptr());
if (!matmul)
return false;
if (matmul->get_transpose_a())
return false;

bool weights_transposed = matmul->get_transpose_b();
const auto& WR_shape = WR.get_shape();
const auto& B_shape = B.get_shape();
const auto& ft_additional_bias_shape = ft_additional_bias.get_shape();

if (WR_shape[0] % 4 != 0)
size_t input_size_plus_hidden_size = weights_transposed ? WR_shape[1] : WR_shape[0];
size_t hidden_size_times_4 = weights_transposed ? WR_shape[0] : WR_shape[1];
if (hidden_size_times_4 % 4 != 0)
return false;
if (WR_shape[0] != B_shape[1])
return false;
if (B_shape[0] != 1)
if (B_shape.size() == 2) {
if (hidden_size_times_4 != B_shape[1])
return false;
if (B_shape[0] != 1)
return false;
} else if (B_shape.size() == 1) {
if (hidden_size_times_4 != B_shape[0])
return false;
} else {
return false;
}
if (shape_size(ft_additional_bias_shape) != 1)
return false;

size_t hidden_size = WR_shape[0] / 4;
size_t hidden_size = hidden_size_times_4 / 4;

if (WR_shape[1] <= hidden_size)
if (input_size_plus_hidden_size <= hidden_size)
return false;

size_t input_size = WR_shape[1] - hidden_size;
size_t input_size = input_size_plus_hidden_size - hidden_size;

const auto& X_shape = X.get_partial_shape();
const auto& H_shape = H.get_partial_shape();
const auto& C_shape = C.get_partial_shape();

if (!H_shape[0].compatible(X_shape[0]))
if (!H_shape[0].compatible(X_shape[0])) // batch size
return false;

if (!C_shape[0].compatible(X_shape[0]))
if (!C_shape[0].compatible(X_shape[0])) // batch size
return false;

if (!X_shape[1].compatible(input_size))
return false;

if (!H_shape[1].compatible(hidden_size))
return false;

if (!C_shape[1].compatible(hidden_size))
return false;

const auto split_axis = ov::as_type_ptr<op::v0::Constant>(pattern_map.at(axis_label).get_node_shared_ptr());
int64_t split_axis_value = split_axis->cast_vector<int64_t>()[0];
if (split_axis_value != 1 && split_axis_value != -1)
return false;

NodeVector split_consumers{pattern_map.at(it_label).get_node_shared_ptr(),
pattern_map.at(ct_label).get_node_shared_ptr(),
pattern_map.at(ot_label).get_node_shared_ptr(),
Expand Down Expand Up @@ -142,22 +243,58 @@ ov::pass::LSTMCellFusion::LSTMCellFusion() {
auto Co_activation = pattern_map.at(Co_activation_label).get_node_shared_ptr();
std::string h_activation_name = get_activation_name(Co_activation);

auto zero = op::v0::Constant::create(element::i32, Shape{}, {0});
auto WR_split = std::make_shared<op::v1::Split>(WR, zero /* axis */, 4);
if (!weights_transposed) {
WR = std::make_shared<op::v1::Transpose>(WR, op::v0::Constant::create(element::i32, Shape{0}, {}));
}
// Split WR to W, R and convert to the layout that OpenVino supports
//
// WR layout (icfo):
//
// W R
// +-------+---+
// i | | |
// +-------+---+
// c | | |
// +-------+---+
// f | | |
// +-------+---+
// o | | |
// +-------+---+
//
//
// W and R layouts that are supported by OpenVino (fico):
//
// W R
// +-------+ +---+
// f | | f | |
// +-------+ +---+
// i | | i | |
// +-------+ +---+
// c | | c | |
// +-------+ +---+
// o | | o | |
// +-------+ +---+
//
auto zero_axis = op::v0::Constant::create(element::i32, Shape{}, {0});
auto WR_split = std::make_shared<op::v1::Split>(WR, zero_axis, 4);
auto WR_fico = std::make_shared<op::v0::Concat>(
OutputVector{WR_split->output(2), WR_split->output(0), WR_split->output(1), WR_split->output(3)},
0);
auto one = op::v0::Constant::create(element::i32, Shape{}, {1});
auto vsplit_axis = op::v0::Constant::create(element::i32, Shape{}, {1});
auto split_lengths = op::v0::Constant::create(element::i32, Shape{2}, {input_size, hidden_size});
auto vsplit = std::make_shared<op::v1::VariadicSplit>(WR_fico, one /* axis */, split_lengths);
auto vsplit = std::make_shared<op::v1::VariadicSplit>(WR_fico, vsplit_axis, split_lengths);
Output<Node> W = vsplit->output(0);
if (auto constant = ov::util::constantfold_subgraph(W))
W = constant;
Output<Node> R = vsplit->output(1);
if (auto constant = ov::util::constantfold_subgraph(R))
R = constant;

auto B_split = std::make_shared<op::v1::Split>(std::make_shared<op::v0::Squeeze>(B, zero), zero /* axis */, 4);
if (B_shape.size() > 1)
B = std::make_shared<op::v0::Squeeze>(B, zero_axis);

// Convert B layout from icfo to fico
auto B_split = std::make_shared<op::v1::Split>(B, zero_axis, 4);
auto B_f =
std::make_shared<op::v1::Add>(B_split->output(2), std::make_shared<op::v0::Squeeze>(ft_additional_bias));

Expand All @@ -176,13 +313,18 @@ ov::pass::LSTMCellFusion::LSTMCellFusion() {
B_fico,
hidden_size,
std::vector<std::string>{f_activation_name, g_activation_name, h_activation_name});

if (transformation_callback(lstm_cell)) {
return false;
}

lstm_cell->set_friendly_name(m.get_match_root()->get_friendly_name());

copy_runtime_info(
{
pattern_map.at(concat_label).get_node_shared_ptr(),
WR.get_node_shared_ptr(),
pattern_map.at(matmul_label).get_node_shared_ptr(),
matmul,
B.get_node_shared_ptr(),
pattern_map.at(bias_add_label).get_node_shared_ptr(),
pattern_map.at(split_label).get_node_shared_ptr(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,41 @@
#include "openvino/op/sigmoid.hpp"
#include "openvino/op/split.hpp"
#include "openvino/op/tanh.hpp"
#include "openvino/pass/constant_folding.hpp"

using namespace ov;

TEST_F(TransformationTestsF, LSTMCellFusion) {
using LSTMCellFusionParam = std::tuple<bool, // true if second input to matmul is transposed
int, // rank of bias (B)
int>; // split axis

class LSTMCellFusionTestSuite : public testing::WithParamInterface<LSTMCellFusionParam>, public TransformationTestsF {};

TEST_P(LSTMCellFusionTestSuite, SubgraphFusedToLSTMCell) {
const auto& param = GetParam();
bool weights_transposed = std::get<0>(param);
int B_rank = std::get<1>(param);
int split_axis_value = std::get<2>(param);
size_t input_size = 3;
size_t hidden_size = 2;

{
auto X = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, input_size});
auto H = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, hidden_size});
auto C = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, hidden_size});
auto concat = std::make_shared<op::v0::Concat>(OutputVector{X, H}, 1);
Shape WR_shape{4 * hidden_size, input_size + hidden_size};
Shape WR_shape = weights_transposed ? Shape{4 * hidden_size, input_size + hidden_size}
: Shape{input_size + hidden_size, 4 * hidden_size};
std::vector<float> WR_values(shape_size(WR_shape));
std::iota(WR_values.begin(), WR_values.end(), 0.0f);
auto WR = op::v0::Constant::create(element::f32, WR_shape, WR_values);
auto matmul = std::make_shared<op::v0::MatMul>(concat, WR, false, true);
Shape B_shape{1, 4 * hidden_size};
auto matmul = std::make_shared<op::v0::MatMul>(concat, WR, false, weights_transposed);
Shape B_shape = B_rank == 2 ? Shape{1, 4 * hidden_size} : Shape{4 * hidden_size};
std::vector<float> B_values(shape_size(B_shape));
std::iota(B_values.begin(), B_values.end(), 0.0f);
auto B = op::v0::Constant::create(element::f32, B_shape, B_values);
auto biasadd = std::make_shared<op::v1::Add>(matmul, B);
auto one = op::v0::Constant::create(element::i32, Shape{}, {1});
auto split = std::make_shared<op::v1::Split>(biasadd, one /* axis */, 4 /* num splits */);
auto split_axis = op::v0::Constant::create(element::i32, Shape{}, {split_axis_value});
auto split = std::make_shared<op::v1::Split>(biasadd, split_axis, 4 /* num splits */);
auto it = std::make_shared<op::v0::Sigmoid>(split->output(0));
auto ct = std::make_shared<op::v0::Tanh>(split->output(1));
auto ft = std::make_shared<op::v0::Sigmoid>(
Expand All @@ -62,28 +73,15 @@ TEST_F(TransformationTestsF, LSTMCellFusion) {
auto concat = std::make_shared<op::v0::Concat>(OutputVector{X, H}, 1);
Shape W_shape{4 * hidden_size, input_size};
Shape R_shape{4 * hidden_size, hidden_size};
std::vector<float> W_values{
20, 21, 22, 25, 26, 27, 0, 1, 2, 5, 6, 7, 10, 11, 12, 15, 16, 17, 30, 31, 32, 35, 36, 37,
};
std::vector<float> W_values = weights_transposed
? std::vector<float>{20, 21, 22, 25, 26, 27, 0, 1, 2, 5, 6, 7,
10, 11, 12, 15, 16, 17, 30, 31, 32, 35, 36, 37}
: std::vector<float>{4, 12, 20, 5, 13, 21, 0, 8, 16, 1, 9, 17,
2, 10, 18, 3, 11, 19, 6, 14, 22, 7, 15, 23};
auto W = op::v0::Constant::create(element::f32, W_shape, W_values);
std::vector<float> R_values{
23,
24,
28,
29,
3,
4,
8,
9,
13,
14,
18,
19,
33,
34,
38,
39,
};
std::vector<float> R_values =
weights_transposed ? std::vector<float>{23, 24, 28, 29, 3, 4, 8, 9, 13, 14, 18, 19, 33, 34, 38, 39}
: std::vector<float>{28, 36, 29, 37, 24, 32, 25, 33, 26, 34, 27, 35, 30, 38, 31, 39};
auto R = op::v0::Constant::create(element::f32, R_shape, R_values);
Shape B_shape{4 * hidden_size};
std::vector<float> B_values{5, 6, 0, 1, 2, 3, 6, 7};
Expand All @@ -106,3 +104,7 @@ TEST_F(TransformationTestsF, LSTMCellFusion) {
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}

INSTANTIATE_TEST_SUITE_P(LSTMCellFusion,
LSTMCellFusionTestSuite,
testing::Combine(testing::Values(false, true), testing::Values(1, 2), testing::Values(1, -1)));

0 comments on commit 2fcaa88

Please sign in to comment.