[CPU][BF16] Concat layer properly handles mixed precision input (openv…
alexey-varyzgin authored and openvino-dev-samples committed Nov 24, 2021
1 parent 664a31f commit f8dec44
Showing 2 changed files with 51 additions and 16 deletions.
inference-engine/src/mkldnn_plugin/mkldnn_graph.cpp (5 additions & 1 deletion)

@@ -1254,7 +1254,11 @@ void MKLDNNGraph::EnforceBF16() {
         if (node->getType() != Input && node->getType() != Output) {
             for (size_t i = 0; i < node->getOriginalInputsNumber(); i++) {
                 const auto &parent = node->getParentEdgesAtPort(i)[0]->getParent();
-                if (!(parent->getType() == Input && parent->isConstant()) && // exclude skipNodes after Constant Inputs
+                /* Skip BF16 enforcement for nodes after Constant Inputs to maintain precision for fusing.
+                 * Precision conversion to BF16 is done automatically if a convolution follows the Constant
+                 * Inputs and the activation is BF16. */
+                if (!(parent->getType() == Input && parent->isConstant() &&
+                      node->getType() != Concatenation) && // Concatenation is an exception: it doesn't change accuracy for BF16 activations
                     !(parent->getType() == Input && node->getType() == Eltwise) && // exclude Eltwise after Input since it supports conversion to BF16
                     node->getOriginalInputPrecisionAtPort(i) == Precision::FP32)
                     node->setOriginalInputPrecisionAtPort(i, Precision::BF16);
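To make the updated condition easier to read, here is a minimal, self-contained sketch that restates it as a standalone predicate. Type, Precision, and shouldEnforceBF16 are hypothetical stand-ins for the plugin's own types, not part of the commit; the real code operates on MKLDNNNode objects.

#include <cstdio>

// Hypothetical stand-ins for the plugin's node-type and precision enums.
enum class Type { Input, Output, Eltwise, Concatenation, Convolution };
enum class Precision { FP32, BF16 };

// Restates the new EnforceBF16 rule: an FP32 input is switched to BF16 unless
// it comes directly from a constant Input (weights stay FP32 for later fused
// conversion), where Concatenation no longer gets that exemption, and unless
// the node is an Eltwise fed straight by an Input.
bool shouldEnforceBF16(Type parentType, bool parentIsConstant, Type nodeType, Precision inPrec) {
    const bool afterConstantInput = parentType == Type::Input && parentIsConstant &&
                                    nodeType != Type::Concatenation;
    const bool eltwiseAfterInput = parentType == Type::Input && nodeType == Type::Eltwise;
    return !afterConstantInput && !eltwiseAfterInput && inPrec == Precision::FP32;
}

int main() {
    // Constant Input -> Convolution: left as FP32 (converted during fusing). Prints 0.
    std::printf("conv after const: %d\n", shouldEnforceBF16(Type::Input, true, Type::Convolution, Precision::FP32));
    // Constant Input -> Concatenation: now enforced to BF16, the behavior this commit adds. Prints 1.
    std::printf("concat after const: %d\n", shouldEnforceBF16(Type::Input, true, Type::Concatenation, Precision::FP32));
}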
Second changed file (46 additions & 15 deletions)

@@ -13,43 +13,74 @@ using namespace InferenceEngine;
 namespace SubgraphTestsDefinitions {
 // Subgraph:
 /*
- *          Parameter    Constant
+ *          Parameter    Constant[FP32/BF16]
  *                  \    /
  *                   \  /
- *                Transpose
- *  Constant      /
- *         \     /
- *         Concat (inPlace)
+ *                Transpose[FP32/BF16]
+ *  Constant[FP32]   /
+ *          \       X  No Reorder
+ *           \     /
+ *         Concat (inPlace)[FP32/BF16]
  *            |
- *          Result
+ *      Convolution [FP32/BF16]
+ *            |
+ *          Result[FP32/BF16]
  */
 
-class ConcatConstantInPlaceTest : virtual public LayerTestsUtils::LayerTestsCommon {
+class ConcatConstantInPlaceTest : public testing::WithParamInterface<InferenceEngine::Precision>, virtual public LayerTestsUtils::LayerTestsCommon {
 public:
+    static std::string getTestCaseName(testing::TestParamInfo<InferenceEngine::Precision> obj) {
+        std::ostringstream result;
+        result << "ConcatConstantInPlaceTest" << obj.param.name();
+        return result.str();
+    }
+
     void SetUp() override {
         targetDevice = CommonTestUtils::DEVICE_CPU;
-        inPrc = outPrc = Precision::FP32;
-        const std::vector<size_t> inputShape = {1, 384, 196};
-        auto inputParams = ngraph::builder::makeParams(ngraph::element::f32, {inputShape, inputShape});
+        if (Precision::BF16 == (inPrc = outPrc = this->GetParam()))
+            configuration.insert({ PluginConfigParams::KEY_ENFORCE_BF16, PluginConfigParams::YES });
+        else
+            configuration.insert({ PluginConfigParams::KEY_ENFORCE_BF16, PluginConfigParams::NO });
+
+        const std::vector<size_t> inputShape = {1, 3, 3, 11};
+        auto inputParams = ngraph::builder::makeParams(ngraph::element::f32, {inputShape});
 
-        auto transposeOrder = ngraph::opset8::Constant::create(ngraph::element::i32, {3}, {0, 2, 1});
+        auto transposeOrder = ngraph::opset8::Constant::create(ngraph::element::i32, {4}, {0, 3, 2, 1});
         auto transpose = std::make_shared<ngraph::opset8::Transpose>(inputParams[0], transposeOrder);
 
-        auto concatConstantInput = ngraph::opset8::Constant::create(ngraph::element::f32, {1, 1, 384}, {10.0f});
+        auto concatConstantInput = ngraph::opset8::Constant::create(ngraph::element::f32, {1, 1, 3, 3}, {10.0f});
         auto concat = ngraph::builder::makeConcat({concatConstantInput, transpose}, 1);
 
-        ngraph::ResultVector results{std::make_shared<ngraph::opset8::Result>(concat)};
+        // convolution
+        std::vector<float> weightValuesFP32(12);
+        ngraph::Shape convFilterShape = { 1, 12, 1, 1 };
+        FuncTestUtils::fillInputsBySinValues(weightValuesFP32.data(), weightValuesFP32.size());
+        auto weightsNode = std::make_shared<ngraph::opset1::Constant>(ngraph::element::f32, convFilterShape, weightValuesFP32);
+        std::shared_ptr<ngraph::Node> conv = std::make_shared<ngraph::opset1::Convolution>(
+            concat, weightsNode, ngraph::Strides({ 1, 1 }), ngraph::CoordinateDiff({ 0, 0 }),
+            ngraph::CoordinateDiff({ 0, 0 }), ngraph::Strides({ 1, 1 }), ngraph::op::PadType::EXPLICIT);
+        conv->set_friendly_name("CONV");
+
+        ngraph::ResultVector results{std::make_shared<ngraph::opset8::Result>(conv)};
         function = std::make_shared<ngraph::Function>(results, inputParams, "ConcatConstantInPlace");
     }
 };
 
 namespace {
-TEST_F(ConcatConstantInPlaceTest, smoke_ConcatConstantInPlaceTest_CPU) {
+TEST_P(ConcatConstantInPlaceTest, smoke_ConcatConstantInPlaceTest_CPU) {
     SKIP_IF_CURRENT_TEST_IS_DISABLED()
 
     Run();
+    if (this->GetParam() == Precision::BF16)
+        CheckNodeOfTypeCount(executableNetwork, "Reorder", 4);
+    else
+        CheckNodeOfTypeCount(executableNetwork, "Reorder", 3);
 }
 
+INSTANTIATE_TEST_SUITE_P(smoke_ConcatConstantInPlaceTest_CPU, ConcatConstantInPlaceTest,
+                         testing::Values(Precision::FP32, Precision::BF16),
+                         ConcatConstantInPlaceTest::getTestCaseName);
+
 } // namespace
 } // namespace SubgraphTestsDefinitions
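The test expects one extra Reorder in the BF16 configuration (4 vs. 3), presumably the precision-converting reorder introduced at the FP32/BF16 boundary once enforcement reaches the Concat input. For reference, a node-type count like the one CheckNodeOfTypeCount performs can be taken from the runtime graph roughly as follows. This is a sketch assuming the 2021.x Inference Engine API; countNodesOfType is a hypothetical helper, not part of the commit.

#include <exec_graph_info.hpp>
#include <ie_core.hpp>
#include <ngraph/variant.hpp>

#include <memory>
#include <string>

// Counts runtime nodes of a given execution-graph layer type, e.g. "Reorder".
size_t countNodesOfType(InferenceEngine::ExecutableNetwork& execNet, const std::string& type) {
    size_t count = 0;
    auto function = execNet.GetExecGraphInfo().getFunction();
    for (const auto& op : function->get_ops()) {
        const auto& rtInfo = op->get_rt_info();
        const auto it = rtInfo.find(ExecGraphInfoSerialization::LAYER_TYPE);
        if (it == rtInfo.end())
            continue;
        // In the 2021.x API the rt_info value is a Variant wrapping std::string.
        const auto value = std::dynamic_pointer_cast<ngraph::VariantWrapper<std::string>>(it->second);
        if (value && value->get() == type)
            ++count;
    }
    return count;
}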
