Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IE][VPU]: Enables Extract Dynamic Batch Transformation #3715

Merged
merged 9 commits into from
Jan 13, 2021
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <vector>

namespace vpu {

enum class SliceMode {
Slice,
Unchanged
};

class SliceConfiguration {
public:
SliceConfiguration() = default;
SliceConfiguration(std::vector<SliceMode> inputs, std::vector<SliceMode> outputs);

bool isSliceSupported() const;
const std::vector<SliceMode>& inputs() const;
const std::vector<SliceMode>& outputs() const;

private:
bool m_isSliceSupported = false;
std::vector<SliceMode> m_inputs;
std::vector<SliceMode> m_outputs;
};

} // namespace vpu
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "ngraph/pass/graph_rewrite.hpp"

#include <memory>

namespace vpu {

class ExtractBatch: public ngraph::pass::FunctionPass {
public:
NGRAPH_RTTI_DECLARATION;

explicit ExtractBatch(std::unordered_set<ngraph::Node::type_info_t> targets);
bool run_on_function(std::shared_ptr<ngraph::Function> function) override;

private:
std::unordered_set<ngraph::Node::type_info_t> targets;
};

} // namespace vpu
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "ngraph/ngraph.hpp"
#include "batch_extraction_configuration.hpp"

namespace vpu {

SliceConfiguration sliceBinaryEltwise(const ngraph::Node& node);

} // namespace vpu
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "ngraph/ngraph.hpp"
#include "batch_extraction_configuration.hpp"

namespace vpu {

SliceConfiguration sliceConvolution(const ngraph::Node& node);

} // namespace vpu
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "ngraph/ngraph.hpp"
#include "batch_extraction_configuration.hpp"

namespace vpu {

SliceConfiguration sliceMatMul(const ngraph::Node& node);

} // namespace vpu
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "ngraph/ngraph.hpp"
#include "batch_extraction_configuration.hpp"

namespace vpu {

SliceConfiguration sliceUnaryEltwise(const ngraph::Node& node);

} // namespace vpu
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
#include "ngraph/node.hpp"
#include "ngraph/type/element_type.hpp"

#include "vpu/utils/error.hpp"

#include <stack>
#include <deque>

namespace vpu {

std::vector<std::int64_t> evaluateTargetShape(const ngraph::Output<ngraph::Node>& value);
Expand All @@ -15,6 +20,60 @@ std::shared_ptr<ngraph::Node> shapeToConstant(const ngraph::element::Type& type,

std::shared_ptr<ngraph::Node> gatherShapeElements(const ngraph::Output<ngraph::Node>&, int startIndex, size_t elemCount);

void printTo(std::ostream& stream, const ngraph::NodeTypeInfo& object);
template<>
inline void printTo(std::ostream& stream, const ngraph::NodeTypeInfo& object) {
stream << object.name << " ver. " << object.version;
}
Comment on lines -18 to +26
Copy link
Contributor Author

@ggladilov ggladilov Dec 23, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still not sure why I had to make this change, but otherwise compiler picks up wrong printTo definition and prints empty string for ngraph::NodeTypeInfo - default printTo implementation


using Nodes = std::unordered_set<ngraph::Node*>;

template<class GetNext, class Visit>
Nodes dfs(ngraph::Node* root, GetNext&& getNext, Visit&& visit) {
Nodes visited;
std::stack<ngraph::Node*> stack{{root}};
This conversation was marked as resolved.
Show resolved Hide resolved
while (!stack.empty()) {
const auto current = stack.top();
stack.pop();

if (!visited.emplace(current).second) {
continue;
}

if (!visit(current)) {
continue;
}

for (const auto& next : getNext(current)) {
stack.push(next);
}
}
return visited;
}

template<class NumEntries, class Visit, class MoveForward>
void bfs(ngraph::Node* root, NumEntries&& getNumEntries, Visit&& visit, MoveForward&& moveForward) {
std::deque<ngraph::Node*> deque{root};
std::unordered_map<ngraph::Node*, std::size_t> visits;
while (!deque.empty()) {
const auto current = deque.front();
deque.pop_front();

const auto numEntries = current == root ? 1 : getNumEntries(current);

const auto visitsCount = ++visits[current];
VPU_THROW_UNLESS(visitsCount <= numEntries, "Encountered loop at {}", current);

if (visitsCount < numEntries) {
VPU_THROW_UNLESS(!deque.empty(), "Node {} should be visited only after all predecessors, but it is not available through all of them", current);
continue;
}

if (!visit(current)) {
continue;
}

moveForward(deque, current);
}
}

} // namespace vpu
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "vpu/utils/error.hpp"
#include "vpu/ngraph/transformations/extract_dynamic_batch/batch_extraction_configuration.hpp"

namespace vpu {

SliceConfiguration::SliceConfiguration(std::vector<SliceMode> inputs, std::vector<SliceMode> outputs)
: m_isSliceSupported(true)
, m_inputs(std::move(inputs))
, m_outputs(std::move(outputs)) {}

bool SliceConfiguration::isSliceSupported() const {
return m_isSliceSupported;
}

const std::vector<SliceMode>& SliceConfiguration::inputs() const {
VPU_THROW_UNLESS(m_isSliceSupported, "Encountered an attempt to access inputs slice configuration for a case when slice is unsupported");
return m_inputs;
}

const std::vector<SliceMode>& SliceConfiguration::outputs() const {
VPU_THROW_UNLESS(m_isSliceSupported, "Encountered an attempt to access outputs slice configuration for a case when slice is unsupported");
return m_outputs;
}

} // namespace vpu

Loading