[CPU] Add and correct tests for int8 LSTM (#17447)

EgorDuplensky authored Jun 14, 2023
1 parent ca0d409 commit d66e322

Showing 9 changed files with 333 additions and 154 deletions.
36 changes: 25 additions & 11 deletions src/plugins/intel_cpu/src/nodes/rnn.cpp
@@ -455,6 +455,9 @@ void RNN::configurePortDataTypes() {
 
     if (one_of(memory::data_type::bf16, inDataTypes[xIdx], inDataTypes[hIdx]))
         inDataTypes[xIdx] = outDataTypes[yIdx] = outDataTypes[hoIdx] = inDataTypes[hIdx] = memory::data_type::bf16; // required by oneDNN.
+
+    if (outDataTypes[yIdx] == memory::data_type::bf16 && one_of(inDataTypes[xIdx], memory::data_type::s8, memory::data_type::u8))
+        outDataTypes[yIdx] = memory::data_type::f32; // oneDNN does not support bf16 output precision for quantized rnn primitive yet
 }
 
 void RNN::getSupportedDescriptors() {
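This hunk's rule in isolation: when the RNN input is quantized (u8/s8), a bf16 output request is downgraded to f32, since oneDNN has no bf16-output configuration for its quantized RNN primitives yet. A minimal standalone sketch of the same fallback (the helper name is ours, not the plugin's):

#include "oneapi/dnnl/dnnl.hpp"

// Hypothetical helper mirroring the fallback above: returns the data type
// the output ports will actually be configured with.
dnnl::memory::data_type resolve_rnn_output_type(dnnl::memory::data_type input,
                                                dnnl::memory::data_type requested_output) {
    using dt = dnnl::memory::data_type;
    const bool quantized_input = (input == dt::s8 || input == dt::u8);
    // oneDNN does not support bf16 output for quantized RNN primitives yet,
    // so keep the int8 input path but emit f32.
    if (requested_output == dt::bf16 && quantized_input)
        return dt::f32;
    return requested_output;
}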
@@ -870,7 +873,8 @@ dnnl::primitive_desc createPrimitiveDescriptor(const dnnl::engine engine,
             wDescs[1],                                                  // Weights state
             wDescs[2],                                                  // Bias
             outDataDescs[RNN::InOutKind::Layer]->getDnnlDesc(),         // Out Data
-            outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc());  // Out State
+            outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc(),   // Out State
+            attr);
     case dnnl::algorithm::vanilla_gru:
         return dnnl::gru_forward::primitive_desc(
             engine,
@@ -882,7 +886,8 @@ dnnl::primitive_desc createPrimitiveDescriptor(const dnnl::engine engine,
             wDescs[1],                                                  // Weights state
             wDescs[2],                                                  // Bias
             outDataDescs[RNN::InOutKind::Layer]->getDnnlDesc(),         // Out Data
-            outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc());  // Out State
+            outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc(),   // Out State
+            attr);
     case dnnl::algorithm::lbr_gru:
         return dnnl::lbr_gru_forward::primitive_desc(
             engine,
@@ -894,7 +899,8 @@ dnnl::primitive_desc createPrimitiveDescriptor(const dnnl::engine engine,
             wDescs[1],                                                  // Weights state
             wDescs[2],                                                  // Bias
             outDataDescs[RNN::InOutKind::Layer]->getDnnlDesc(),         // Out Data
-            outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc());  // Out State
+            outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc(),   // Out State
+            attr);
     case dnnl::algorithm::vanilla_lstm:
         return dnnl::lstm_forward::primitive_desc(
             engine,
@@ -908,7 +914,8 @@ dnnl::primitive_desc createPrimitiveDescriptor(const dnnl::engine engine,
             wDescs[2],                                                  // Bias
             outDataDescs[RNN::InOutKind::Layer]->getDnnlDesc(),         // Out Data
             outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc(),   // Out State
-            outDataDescs[RNN::InOutKind::CellState]->getDnnlDesc());    // Out State C
+            outDataDescs[RNN::InOutKind::CellState]->getDnnlDesc(),     // Out State C
+            attr);
     case dnnl::algorithm::vanilla_augru:
         return dnnl::augru_forward::primitive_desc(
             engine,
@@ -921,7 +928,8 @@ dnnl::primitive_desc createPrimitiveDescriptor(const dnnl::engine engine,
             wDescs[1],                                                  // Weights state
             wDescs[2],                                                  // Bias
             outDataDescs[RNN::InOutKind::Layer]->getDnnlDesc(),         // Out Data
-            outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc());  // Out State
+            outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc(),   // Out State
+            attr);
     case dnnl::algorithm::lbr_augru:
         return dnnl::lbr_augru_forward::primitive_desc(
             engine,
@@ -934,7 +942,8 @@ dnnl::primitive_desc createPrimitiveDescriptor(const dnnl::engine engine,
             wDescs[1],                                                  // Weights state
             wDescs[2],                                                  // Bias
             outDataDescs[RNN::InOutKind::Layer]->getDnnlDesc(),         // Out Data
-            outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc());  // Out State
+            outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc(),   // Out State
+            attr);
     default:
         IE_THROW() << "RNN. Unknown cell type";
     }
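All six cases change the same way: the primitive attributes built in initPrimitiveAttr() (below) are now passed to the oneDNN primitive descriptor; before this change they never reached the primitive, so the quantization parameters were silently dropped. A self-contained sketch of the int8 LSTM descriptor creation this enables, written directly against the oneDNN C++ API; all sizes and scale values are placeholders, not the plugin's:

#include <vector>
#include "oneapi/dnnl/dnnl.hpp"

using namespace dnnl;

int main() {
    engine eng(engine::kind::cpu, 0);

    // Hypothetical sizes: T time steps, N batch, C channels (= hidden size),
    // G gates (4 for LSTM), L layers, D directions.
    const memory::dim T = 2, N = 1, C = 16, G = 4, L = 1, D = 1;

    // int8 LSTM configuration: u8 activations, s8 weights, f32 bias/cell state.
    memory::desc src_layer({T, N, C}, memory::data_type::u8, memory::format_tag::tnc);
    memory::desc src_iter({L, D, N, C}, memory::data_type::u8, memory::format_tag::ldnc);
    memory::desc src_iter_c({L, D, N, C}, memory::data_type::f32, memory::format_tag::ldnc);
    memory::desc weights_layer({L, D, C, G, C}, memory::data_type::s8, memory::format_tag::any);
    memory::desc weights_iter({L, D, C, G, C}, memory::data_type::s8, memory::format_tag::any);
    memory::desc bias({L, D, G, C}, memory::data_type::f32, memory::format_tag::ldgo);
    // f32 Y output, matching the fallback in configurePortDataTypes() above;
    // the hidden state stays u8.
    memory::desc dst_layer({T, N, C}, memory::data_type::f32, memory::format_tag::tnc);
    memory::desc dst_iter({L, D, N, C}, memory::data_type::u8, memory::format_tag::ldnc);
    memory::desc dst_iter_c({L, D, N, C}, memory::data_type::f32, memory::format_tag::ldnc);

    // Quantization parameters: without forwarding `attr` to the descriptor
    // (the fix above) these settings would have no effect.
    primitive_attr attr;
    attr.set_rnn_data_qparams(/*scale=*/63.5f, /*shift=*/64.0f);  // placeholder values
    attr.set_rnn_weights_qparams(/*mask=*/(1 << 3) + (1 << 4),    // per-gate, per-channel
                                 std::vector<float>(G * C, 127.0f));

    lstm_forward::primitive_desc pd(eng, prop_kind::forward_inference,
                                    rnn_direction::unidirectional_left2right,
                                    src_layer, src_iter, src_iter_c,
                                    weights_layer, weights_iter, bias,
                                    dst_layer, dst_iter, dst_iter_c, attr);
    lstm_forward lstm(pd); // would throw at creation if this int8 config were unsupported
    return 0;
}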
@@ -979,19 +988,19 @@ void RNN::createDescriptor(const std::vector<MemoryDescPtr> &inputDesc,
 
     // Fill supported config
     NodeConfig config;
-    for (size_t i = 0; i < inputDesc.size(); i++) {
+    for (const auto &desc : inputDesc) {
         PortConfig dataConfig;
         dataConfig.inPlace(-1);
         dataConfig.constant(false);
-        dataConfig.setMemDesc(inputDesc[i]);
+        dataConfig.setMemDesc(desc);
         config.inConfs.push_back(dataConfig);
     }
 
-    for (size_t i = 0; i < outputDesc.size(); i++) {
+    for (const auto &desc : outputDesc) {
         PortConfig dataConfig;
         dataConfig.inPlace(-1);
         dataConfig.constant(false);
-        dataConfig.setMemDesc(outputDesc[i]);
+        dataConfig.setMemDesc(desc);
         config.outConfs.push_back(dataConfig);
     }
 
@@ -1003,7 +1012,12 @@ Node::AttrPtr RNN::initPrimitiveAttr() {
     attr->set_scratchpad_mode(dnnl::scratchpad_mode::user);
 
     if (one_of(inDataTypes[xIdx], memory::data_type::u8, memory::data_type::s8)) {
-        const int weightsScaleMask = 0;
+        const int weightsScaleMask = 0
+            + (1 << 3)  // bit, indicating the unique scales for `g` dim in `ldigo`
+            + (1 << 4); // bit, indicating the unique scales for `o` dim in `ldigo`
+
+        DEBUG_LOG(getName(), ": inputScale: ", inputScale, ", inputShift: ", inputShift,
+                  ", weightsScaleMask: ", weightsScaleMask, ", weightsScales[0]: ", weightsScales[0]);
 
         attr->set_rnn_weights_qparams(weightsScaleMask, weightsScales);
         attr->set_rnn_data_qparams(inputScale, inputShift);
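The mask follows oneDNN's convention for the 5-D `ldigo` weights layout: setting bits 3 and 4 requests a distinct scale per gate and per output channel, so the scales vector must hold num_gates * hidden_size entries. A small sketch with hypothetical sizes:

#include <vector>

// `ldigo` dims: layers, directions, input channels, gates, output channels.
const int weightsScaleMask = (1 << 3) + (1 << 4); // unique scales along `g` and `o`
const size_t numGates   = 4;                      // LSTM gates
const size_t hiddenSize = 128;                    // hypothetical
std::vector<float> weightsScales(numGates * hiddenSize, 127.f); // one per (g, o) pair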
@@ -3,6 +3,7 @@
 //
 #include "convert_fq_rnn_to_quantized_rnn.hpp"
 
+#include <algorithm>
 #include <ngraph/opsets/opset9.hpp>
 #include <ngraph/pattern/op/wrap_type.hpp>
 #include <ngraph/pattern/op/or.hpp>
@@ -14,6 +15,7 @@
 
 #include "ie_common.h"
 #include "itt.hpp"
+#include "openvino/core/type/element_type.hpp"
 
 #include <stdexcept>
 #include <vector>
@@ -164,11 +166,15 @@ ov::intel_cpu::ConvertFqRnnToQuantizedRnn::ConvertFqRnnToQuantizedRnn() {
     if (*input_scale_ptr == 0.f)
         OPENVINO_THROW("Cannot handle zero input scale");
 
-    const float input_scale = 1 / *input_scale_ptr;
-    const std::vector<float> weights_scales = weights_scale_constant->get_vector<float>();
+    const float input_scale = 1 / *input_scale_ptr;
+    std::vector<float> weights_scales = weights_scale_constant->get_vector<float>();
+
+    // transform dequantization scales into quantization ones
+    std::transform(weights_scales.begin(), weights_scales.end(), weights_scales.begin(), [](float& scale) { return 1 / scale; });
 
     auto& runtime_info = rnn_quantized->get_rt_info();
 
+    // use runtime information to store input and weight scales
     runtime_info["inputScale"] = input_scale;
     runtime_info["weightsScales"] = weights_scales;

@@ -178,7 +184,6 @@ ov::intel_cpu::ConvertFqRnnToQuantizedRnn::ConvertFqRnnToQuantizedRnn() {
     if (input_shift_it != pattern_map.end()) {
         const auto input_shift_constant = std::dynamic_pointer_cast<ngraph::opset9::Constant>(input_shift_it->second.get_node_shared_ptr());
         const float* input_shift_ptr = input_shift_constant->get_data_ptr<float>();
-
         runtime_info["inputShift"] = *input_shift_ptr;
     }
 
@@ -207,6 +212,7 @@ ov::intel_cpu::ConvertFqRnnToQuantizedRnn::ConvertFqRnnToQuantizedRnn() {
     }
 
     auto new_multiply = multiply->clone_with_new_inputs({multiply_in, multiply->input_value(1)});
+    new_multiply->set_friendly_name(rnn_quantized->get_friendly_name() + ".1");
 
     for (auto output : H_outputs) {
         output.replace_source_output(new_multiply);
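The pass communicates with the CPU plugin purely through rt_info; the keys written above are the ones rnn.cpp's initPrimitiveAttr() consumes. A hypothetical sketch of the reading side, to make the contract concrete (the function is ours, not part of the pass):

#include <memory>
#include <vector>
#include "openvino/core/node.hpp"

// Hypothetical consumer of the runtime info written by the pass above.
void read_rnn_qparams(const std::shared_ptr<ov::Node>& node) {
    const auto& rt = node->get_rt_info();
    const float input_scale    = rt.at("inputScale").as<float>();
    const auto weights_scales  = rt.at("weightsScales").as<std::vector<float>>();
    float input_shift = 0.f;
    if (rt.count("inputShift"))  // optional: only written when the FQ has a shift
        input_shift = rt.at("inputShift").as<float>();
    (void)input_scale; (void)weights_scales; (void)input_shift;
}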
@@ -12,6 +12,63 @@
  * with FQ operations on the inputs and forms a new TypeRelaxed operation
  * with quantization parameters as runtime parameters of the operation.
  * @todo add ascii graph examples
+ *
+ * Before:
+ *
+ *  +-------+   +-------+   +-------+  +-------+  +-------+  +-------+
+ *  |   X   |   |   H   |   |   C   |  |   W   |  |   R   |  |   B   |
+ *  |       |   |       |   |       |  |       |  |       |  |       |
+ *  | u8/i8 |   | u8/i8 |   |  f32  |  |  i8   |  |  i8   |  |  f32  |
+ *  +---+---+   +---+---+   +---+---+  +---+---+  +---+---+  +---+---+
+ *      |           |           |          |          |          |
+ *  +---v---+   +---v---+       |      +---v---+  +---v---+      |
+ *  |       |   |       |       |      |       |  |       |      |
+ *  |  deq  |   |  deq  |       |      |  deq  |  |  deq  |      |
+ *  |       |   |       |       |      |       |  |       |      |
+ *  +---+---+   +---+---+       |      +---+---+  +---+---+      |
+ *      |           |           |          |          |          |
+ *  +---v-----------v-----------v----------v----------v----------v---+
+ *  |                                                                |
+ *  |                LSTMSequence / GRUSequence (f32)                |
+ *  |                                                                |
+ *  +---------------+-----------+----------+-------------------------+
+ *                  |           |          |
+ *                  |Y f32      |Ho f32    |Co f32
+ *                  |           |          |
+ *                  v           v          v
+ *
+ *
+ * After:
+ *
+ *  +-------+   +-------+   +-------+  +-------+  +-------+  +-------+
+ *  |   X   |   |   H   |   |   C   |  |   W   |  |   R   |  |   B   |
+ *  |       |   |       |   |       |  |       |  |       |  |       |
+ *  | u8/i8 |   | u8/i8 |   |  f32  |  |  i8   |  |  i8   |  |  f32  |
+ *  +---+---+   +---+---+   +---+---+  +---+---+  +---+---+  +---+---+
+ *      |           |           |          |          |          |
+ *      |           |           |          |          |          |
+ *  +---v-----------v-----------v----------v----------v----------v---+
+ *  |  TypeRelaxed                              rt_info[inputScale]  |
+ *  |                                                                |
+ *  |  LSTMSequence / GRUSequence (u8/i8)    rt_info[weightsScales]  |
+ *  +---------------+-----------+----------+-------------------------+
+ *                  |           |          |
+ *                  |Y f32      |Ho u8/i8  |Co f32
+ *                  |       +---v---+      |
+ *                  |       |       |      |
+ *                  |       |  deq  |      |
+ *                  |       |       |      |
+ *                  |       +---+---+      |
+ *                  |           |          |
+ *                  v           v          v
  */
 
 namespace ov {
@@ -87,6 +87,7 @@
 #include "low_precision/convolution_backprop_data.hpp"
 #include "low_precision/group_convolution.hpp"
 #include "low_precision/multiply_to_group_convolution.hpp"
+#include "low_precision/recurrent_cell.hpp"
 #include "low_precision/network_helper.hpp"
 #include "low_precision/rt_info/bias_attribute.hpp"
 #include "transformations/low_precision/mark_dequantization_subgraph.hpp"
@@ -504,10 +505,10 @@ void Transformations::Lpt(const bool hasINT16orINT32Levels, const std::vector<ov
             {{1}, {ov::element::i8}}
         }),
         PrecisionsRestriction::create<ov::opset5::LSTMSequence>({
-            {{0, 1}, {ov::element::u8, ov::element::i8}},
+            {{0, 1}, {ov::element::u8}}
         }),
         PrecisionsRestriction::create<ov::opset6::GRUSequence>({
-            {{0, 1}, {ov::element::u8, ov::element::i8}},
+            {{0, 1}, {ov::element::u8}}
         }),
     });
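For reference, each PrecisionsRestriction entry maps a set of input ports to the precisions LPT may assign them; a short annotated restatement of the tightened rule (the motivation given in the comment is our assumption, not stated in the commit):

// Ports {0, 1} of LSTMSequence/GRUSequence are the layer input X and the
// initial hidden state H. Dropping ov::element::i8 leaves u8 as the only
// allowed activation precision, presumably matching the unsigned-activation
// configuration of oneDNN's int8 RNN primitives (weights remain i8).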

@@ -548,6 +549,7 @@ void Transformations::Lpt(const bool hasINT16orINT32Levels, const std::vector<ov
         return ov::marked_as_bias(node);
     });
 
+    CPU_DISABLE_PASS_ARM(lptManager, ngraph::pass::low_precision::RecurrentCellTransformation);
     CPU_DISABLE_PASS_COMMON(lptManager, ngraph::pass::low_precision::MultiplyToGroupConvolutionTransformation);
 
     lptManager.run_passes(model);
@@ -609,7 +611,7 @@ void Transformations::PostLpt() {
     }
 
     // Execute before snippets. Otherwise FQ will be converted to Subgraph
-    CPU_REGISTER_PASS_COMMON(postLPTPassManager, ConvertFqRnnToQuantizedRnn);
+    CPU_REGISTER_PASS_X64(postLPTPassManager, ConvertFqRnnToQuantizedRnn);
     postLPTPassManager.run_passes(model);
 }
 
@@ -21,7 +21,7 @@ const std::vector<ngraph::pass::low_precision::LayerTransformation::Params> tras
 namespace testValues1 {
 
 const std::vector<LayerTestsDefinitions::RecurrentCellTransformationParam> params = {
-    // LSTMCell
+    // LSTMSequence
     {
         // X
         {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
@@ -47,8 +47,8 @@ const std::vector<LayerTestsDefinitions::RecurrentCellTransformationParam> param
         {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
         {},
         {{}, {}, {}},
-        ngraph::builder::subgraph::RecurrentCellFunction::RNNType::LSTMCell,
-        "RNNCell",
+        ngraph::builder::subgraph::RecurrentCellFunction::RNNType::LSTMSequence,
+        "RNNSeq",
         "U8"
     },
     // asymmetrical FQ on weights
@@ -77,14 +77,14 @@ const std::vector<LayerTestsDefinitions::RecurrentCellTransformationParam> param
         {256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
         {},
         {{}, {}, {}},
-        ngraph::builder::subgraph::RecurrentCellFunction::RNNType::LSTMCell,
-        "RNNCell",
+        ngraph::builder::subgraph::RecurrentCellFunction::RNNType::LSTMSequence,
+        "RNNSeq",
         "FP32"
     }
 };
 
-const std::vector<std::vector<ngraph::PartialShape>> activations_shapes = {{{1, 16}, {1, 128}, {1, 128}}};
-const std::vector<std::vector<ngraph::Shape>> weights_shapes = {{{512, 16}, {512, 128}, {512}}};
+const std::vector<std::vector<ngraph::PartialShape>> activations_shapes = {{{1, 2, 16}, {1, 1, 128}, {1, 1, 128}}};
+const std::vector<std::vector<ngraph::Shape>> weights_shapes = {{{1, 512, 16}, {1, 512, 128}, {1, 512}}};
 
 INSTANTIATE_TEST_SUITE_P(smoke_LPT, RecurrentCellTransformation,
     ::testing::Combine(
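The test retargets the sequence op: the RNNType/name parameters switch from LSTMCell/"RNNCell" to LSTMSequence/"RNNSeq", and the shapes gain the seq_len and num_directions axes accordingly. A quick consistency note on the new numbers (our annotation; hidden size 128, 4 LSTM gates):

// Shape bookkeeping for the LSTMSequence test case:
//   X:    {batch=1, seq_len=2, input_size=16}
//   H, C: {batch=1, num_directions=1, hidden_size=128}
//   W:    {num_directions=1, 4 * hidden_size=512, input_size=16}
//   R:    {num_directions=1, 4 * hidden_size=512, hidden_size=128}
//   B:    {num_directions=1, 4 * hidden_size=512}
static_assert(4 * 128 == 512, "W/R/B gate dimension must be 4 * hidden_size");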
@@ -126,8 +126,8 @@ const std::vector<LayerTestsDefinitions::RecurrentCellTransformationParam> param
         {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
         {},
         {{}, {}, {}},
-        ngraph::builder::subgraph::RecurrentCellFunction::RNNType::GRU,
-        "RNNCell",
+        ngraph::builder::subgraph::RecurrentCellFunction::RNNType::GRUSequence,
+        "RNNSeq",
         "U8"
     },
     // asymmetrical FQ on weights
@@ -156,14 +156,14 @@ const std::vector<LayerTestsDefinitions::RecurrentCellTransformationParam> param
         {256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
         {},
         {{}, {}, {}},
-        ngraph::builder::subgraph::RecurrentCellFunction::RNNType::GRU,
-        "RNNCell",
+        ngraph::builder::subgraph::RecurrentCellFunction::RNNType::GRUSequence,
+        "RNNSeq",
         "FP32"
     }
 };
 
-const std::vector<std::vector<ngraph::PartialShape>> activations_shapes = {{{2, 3}, {2, 3}, {}}};
-const std::vector<std::vector<ngraph::Shape>> weights_shapes = {{{9, 3}, {9, 3}, {9}}};
+const std::vector<std::vector<ngraph::PartialShape>> activations_shapes = {{{1, 1, 3}, {1, 1, 3}, {}}};
+const std::vector<std::vector<ngraph::Shape>> weights_shapes = {{{1, 9, 3}, {1, 9, 3}, {1, 9}}};
 
 INSTANTIATE_TEST_SUITE_P(smoke_LPT, RecurrentCellTransformation,
     ::testing::Combine(
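The same consistency note for the GRU suite (our annotation; hidden size 3, 3 GRU gates):

// Shape bookkeeping for the GRUSequence test case:
//   X:   {batch=1, seq_len=1, input_size=3}
//   H:   {batch=1, num_directions=1, hidden_size=3}
//   W/R: {num_directions=1, 3 * hidden_size=9, input_size/hidden_size=3}
//   B:   {num_directions=1, 3 * hidden_size=9}
static_assert(3 * 3 == 9, "W/R/B gate dimension must be 3 * hidden_size");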