Skip to content

Commit

Permalink
added RNNCell node support
Browse files Browse the repository at this point in the history
  • Loading branch information
ndemashov committed Mar 28, 2022
1 parent f10534b commit 822dd07
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 79 deletions.
13 changes: 11 additions & 2 deletions src/common/low_precision_transformations/src/lstm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ LSTMTransformation::LSTMTransformation(const Params& params) : LayerTransformati
const auto gru_cell_with_dequantizations_without_subtract = ngraph::pattern::wrap_type<ngraph::opset4::GRUCell>(
{dequantization_multiply_without_subtract_X, dequantization_multiply_without_subtract_H, fq_W, fq_R, B});

const auto rnn_cell = ngraph::pattern::wrap_type<ngraph::opset4::RNNCell>({fq_X, fq_H, fq_W, fq_R, B});
const auto rnn_cell_with_dequantizations = ngraph::pattern::wrap_type<ngraph::opset4::RNNCell>(
{dequantization_multiply_X, dequantization_multiply_H, fq_W, fq_R, B});
const auto rnn_cell_with_dequantizations_without_subtract = ngraph::pattern::wrap_type<ngraph::opset4::RNNCell>(
{dequantization_multiply_without_subtract_X, dequantization_multiply_without_subtract_H, fq_W, fq_R, B});

ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
auto op = m.get_match_root();
if (transformation_callback(op)) {
Expand All @@ -111,7 +117,10 @@ LSTMTransformation::LSTMTransformation(const Params& params) : LayerTransformati
lstm_sequence_with_dequantizations_without_subtract,
gru_cell,
gru_cell_with_dequantizations,
gru_cell_with_dequantizations_without_subtract}),
gru_cell_with_dequantizations_without_subtract,
rnn_cell,
rnn_cell_with_dequantizations,
rnn_cell_with_dequantizations_without_subtract}),
"LSTM");
this->register_matcher(m, callback);
}
Expand Down Expand Up @@ -177,7 +186,7 @@ bool LSTMTransformation::isPrecisionPreserved(std::shared_ptr<Node>) const noexc
return true;
}

void LSTMTransformation::propagateSkipCleanupAttribute(std::shared_ptr<Node> multiply){
void LSTMTransformation::propagateSkipCleanupAttribute(std::shared_ptr<Node> multiply) {
SkipCleanupAttribute::create(multiply, true);
auto multiply_parent = multiply->get_input_node_shared_ptr(0);
SkipCleanupAttribute::create(multiply_parent, true);
Expand Down
39 changes: 10 additions & 29 deletions src/plugins/intel_cpu/src/graph_dumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,14 @@

#include "graph_dumper.h"

#include "utils/debug_capabilities.h"
#include <ie_ngraph_utils.hpp>
#include "exec_graph_info.hpp"
#include "ie_common.h"
#include "ie_ngraph_utils.hpp"
#include "mkldnn_debug.h"
#include <nodes/tensoriterator.h>
#include "cpu_types.h"
#include "utils/debug_capabilities.h"

#include <ngraph/variant.hpp>
#include "ngraph/ngraph.hpp"
#include <ngraph/pass/manager.hpp>
#include "ngraph/op/tensor_iterator.hpp"
#include <openvino/pass/serialize.hpp>

#include <vector>
Expand Down Expand Up @@ -43,19 +39,7 @@ std::map<std::string, std::string> extract_node_metadata(const NodePtr &node) {
// Path to print actual name for extension layers
serialization_info[ExecGraphInfoSerialization::LAYER_TYPE] = node->getTypeStr();
} else {
std::string layerTypeStr;

auto layerType = node->getType();

/* replace CPU proprietary input/output types with the ones which serializer can process */
if (layerType == Type::Input)
layerTypeStr = "Parameter";
else if (layerType == Type::Output)
layerTypeStr = "Result";
else
layerTypeStr = NameFromType(node->getType());

serialization_info[ExecGraphInfoSerialization::LAYER_TYPE] = layerTypeStr;
serialization_info[ExecGraphInfoSerialization::LAYER_TYPE] = NameFromType(node->getType());
}

// Original layers
Expand Down Expand Up @@ -189,16 +173,12 @@ std::shared_ptr<ngraph::Function> dump_graph_as_ie_ngraph_net(const Graph &graph
results.emplace_back(std::make_shared<ngraph::op::Result>(get_inputs(node).back()));
return_node = results.back();
} else {
if (node->getAlgorithm() == Algorithm::TensorIteratorCommon) {
return_node = create_ngraph_ti_node(node);
} else {
return_node = std::make_shared<ExecGraphInfoSerialization::ExecutionNode>(
get_inputs(node), node->getSelectedPrimitiveDescriptor()->getConfig().outConfs.size());

for (size_t port = 0; port < return_node->get_output_size(); ++port) {
auto& desc = node->getChildEdgeAt(port)->getMemory().getDesc();
return_node->set_output_type(port, details::convertPrecision(desc.getPrecision()), desc.getShape().toPartialShape());
}
return_node = std::make_shared<ExecGraphInfoSerialization::ExecutionNode>(
get_inputs(node), node->getSelectedPrimitiveDescriptor()->getConfig().outConfs.size());

for (size_t port = 0; port < return_node->get_output_size(); ++port) {
auto& desc = node->getChildEdgeAt(port)->getMemory().getDesc();
return_node->set_output_type(port, details::convertPrecision(desc.getPrecision()), desc.getShape().toPartialShape());
}
}

Expand All @@ -213,6 +193,7 @@ std::shared_ptr<ngraph::Function> dump_graph_as_ie_ngraph_net(const Graph &graph
return return_node;
};

ngraph::NodeVector nodes;
nodes.reserve(graph.graphNodes.size());
for (auto &node : graph.graphNodes) { // important: graph.graphNodes are in topological order
nodes.emplace_back(create_ngraph_node(node));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,22 +62,19 @@ class LSTMTransformationTestValues {
public:
LSTMTransformationTestValues() = default;
LSTMTransformationTestValues(const TestTransformationParams& params,
const bool bias,
const LSTMFunction::RNNType type,
const LSTMTransformationValues& actual,
const LSTMTransformationValues& result,
const bool addNotPrecisionPreservedOperation = false,
const bool checkIntervalsAlignmentAttributes = true)
: params(params),
bias(bias),
type(type),
actual(actual),
result(result),
addNotPrecisionPreservedOperation(addNotPrecisionPreservedOperation),
checkIntervalsAlignmentAttributes(checkIntervalsAlignmentAttributes) {}

TestTransformationParams params;
bool bias;
LSTMFunction::RNNType type;
LSTMTransformationValues actual;
LSTMTransformationValues result;
Expand All @@ -88,7 +85,7 @@ class LSTMTransformationTestValues {
};

inline std::ostream& operator<<(std::ostream& out, const LSTMTransformationTestValues& values) {
return out << "_" << values.bias << "_" << values.actual << "_" << values.result;
return out << "_" << values.actual << "_" << values.result;
}

typedef std::tuple<ngraph::element::Type, std::vector<ngraph::PartialShape>, std::vector<ngraph::Shape>, LSTMTransformationTestValues>
Expand All @@ -106,7 +103,6 @@ class LSTMTransformation : public LayerTransformation, public testing::WithParam
activations_shapes,
weights_shapes,
testValues.type,
testValues.bias,
{
testValues.actual.fakeQuantize_X,
testValues.actual.fakeQuantize_H,
Expand Down Expand Up @@ -162,7 +158,6 @@ class LSTMTransformation : public LayerTransformation, public testing::WithParam
activations_shapes,
weights_shapes,
testValues.type,
testValues.bias,
{
testValues.result.fakeQuantize_X,
testValues.result.fakeQuantize_H,
Expand Down Expand Up @@ -195,8 +190,6 @@ class LSTMTransformation : public LayerTransformation, public testing::WithParam

std::ostringstream result;
result << LayerTransformation::getTestCaseNameByParams(precision, activations_shapes[0], testValues.params)
<< "_"
<< (testValues.bias ? "with_bias_" : "without_bias_")
<< "_" << testValues.actual << "_" << testValues.result << "_";
return result.str();
}
Expand Down Expand Up @@ -235,7 +228,6 @@ const std::vector<std::vector<ngraph::Shape>> weights_shapes = {{{512, 16}, {512
const std::vector<LSTMTransformationTestValues> testValues = {
// LSTM Cell
{LayerTransformation::createParamsU8I8(),
false,
LSTMFunction::RNNType::LSTMCell,
{
// X
Expand Down Expand Up @@ -318,7 +310,6 @@ const std::vector<std::vector<ngraph::Shape>> weights_shapes = {{{1, 512, 16}, {
const std::vector<LSTMTransformationTestValues> testValues = {
// LSTM Sequence
{LayerTransformation::createParamsU8I8(),
false,
LSTMFunction::RNNType::LSTMSequence,
{
// X
Expand Down Expand Up @@ -401,7 +392,6 @@ const std::vector<std::vector<ngraph::Shape>> weights_shapes = {{{9, 3}, {9, 3},
const std::vector<LSTMTransformationTestValues> testValues = {
// GRU
{LayerTransformation::createParamsU8I8(),
false,
LSTMFunction::RNNType::GRU,
{
// X
Expand Down Expand Up @@ -454,7 +444,7 @@ const std::vector<LSTMTransformationTestValues> testValues = {
{},
{0.01f}
},
// R
// R
{},
{},
{
Expand All @@ -475,4 +465,86 @@ INSTANTIATE_TEST_SUITE_P(
::testing::ValuesIn(testValues)),
LSTMTransformation::getTestCaseName);
} // namespace testValues3

namespace testValues4 {
const std::vector<std::vector<ngraph::PartialShape>> activations_shapes = {{{2, 3}, {2, 3}, {}}};

const std::vector<std::vector<ngraph::Shape>> weights_shapes = {{{3, 3}, {3, 3}, {9}}};

const std::vector<LSTMTransformationTestValues> testValues = {
// RNNCell
{LayerTransformation::createParamsU8I8(),
LSTMFunction::RNNType::RNNCell,
{
// X
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
{ngraph::element::u8},
{
{element::f32},
{},
{0.01f},
},
// H
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
{ngraph::element::u8},
{
{element::f32},
{},
{0.01f},
},
// W
{255ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}},
{},
{{}, {}, {}},
// R
{255ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}},
{},
{{}, {}, {}},
},
{
// X
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
{ngraph::element::u8},
{
{element::f32},
{},
{0.01f},
},
// H
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
{ngraph::element::u8},
{
{element::f32},
{},
{0.01f},
},
// W
{},
{},
{
{element::f32},
{},
{0.01f}
},
// R
{},
{},
{
{element::f32},
{},
{0.01f}
},
}
}
};
INSTANTIATE_TEST_SUITE_P(
smoke_LPT,
LSTMTransformation,
::testing::Combine(
::testing::ValuesIn(precisions),
::testing::ValuesIn(activations_shapes),
::testing::ValuesIn(weights_shapes),
::testing::ValuesIn(testValues)),
LSTMTransformation::getTestCaseName);
} // namespace testValues4
} // namespace
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@ namespace subgraph {

class LSTMFunction {
public:
enum class RNNType { LSTMCell, LSTMSequence, GRU };
enum class RNNType { LSTMCell, LSTMSequence, GRU, RNNCell };

static std::shared_ptr<ngraph::Function> get(
const ngraph::element::Type inputPrecision,
const std::vector<ngraph::PartialShape>& inputActivationsShapes,
const std::vector<ngraph::Shape>& inputWeightsShapes,
const RNNType type,
const bool bias,
const std::vector<FakeQuantizeOnDataWithConstant>& fqOnDatas,
const std::vector<DequantizationOperations::Convert>& converts,
const std::vector<DequantizationOperations>& dequantizations,
Expand Down
Loading

0 comments on commit 822dd07

Please sign in to comment.