Skip to content

Commit

Permalink
Skip mark_node in shape flow for broadcast node if dependency nodes a…
Browse files Browse the repository at this point in the history
…re data and shape_of

Signed-off-by: yuan.xiong <[email protected]>
  • Loading branch information
yuanxion committed Dec 12, 2024
1 parent 19d048e commit 457a67a
Showing 1 changed file with 6 additions and 103 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,119 +16,22 @@

using namespace cldnn;

namespace {
bool has_input_layout_dep(const std::vector<std::pair<cldnn::program_node*, int>>& shape_of_deps) {
for (auto& shape_of_dep : shape_of_deps) {
// input_layout node
if (shape_of_dep.first->is_type<input_layout>()) {
return true;
}
}
return false;
}

bool has_shape_of_dep(const std::vector<std::pair<cldnn::program_node*, int>>& broadcast_deps) {
for (auto& broadcast_dep : broadcast_deps) {
// shape_of node
if (broadcast_dep.first->is_type<shape_of>()) {
auto& shape_of_deps = broadcast_dep.first->get_dependencies();
return has_input_layout_dep(shape_of_deps);
}
}
return false;
}

bool has_broadcast_dep(const std::vector<std::pair<cldnn::program_node*, int>>& reorder_deps) {
for (auto& reorder_dep : reorder_deps) {
// broadcast node
if (reorder_dep.first->is_type<broadcast>()) {
auto& broadcast_deps = reorder_dep.first->get_dependencies();
return has_shape_of_dep(broadcast_deps);
}
}
return false;
}

bool has_reorder_reoder_dep(const std::vector<std::pair<cldnn::program_node*, int>>& eltwise_deps) {
for (auto& eltwise_dep : eltwise_deps) {
// reorder node (reorder -> eltwise)
if (eltwise_dep.first->is_type<reorder>()) {
auto& eltwise_dep_reorder_deps = eltwise_dep.first->get_dependencies();

for (auto& eltwise_dep_reorder_dep : eltwise_dep_reorder_deps) {
// reorder node (broadcast -> reorder)
if (eltwise_dep_reorder_dep.first->is_type<reorder>()) {
auto& reorder_dep_reorder_deps = eltwise_dep_reorder_dep.first->get_dependencies();
return has_broadcast_dep(reorder_dep_reorder_deps);
}
}
}
}
return false;
}

bool has_eltwise_dep(const std::vector<std::pair<cldnn::program_node*, int>>& reorder_deps) {
for (auto& reorder_dep : reorder_deps) {
// eltwise node
if (reorder_dep.first->is_type<eltwise>()) {
auto& eltwise_deps = reorder_dep.first->get_dependencies();
return has_reorder_reoder_dep(eltwise_deps);
}
}
return false;
}

bool has_reorder_dep(const std::vector<std::pair<cldnn::program_node*, int>>& conv_deps) {
for (auto& conv_dep : conv_deps) {
//if (conv_dep.first->id().find(dequantize_name) != std::string::npos) {

// reorder node ( reorder -> convolution)
if (conv_dep.first->is_type<reorder>()) {
auto& reorder_deps = conv_dep.first->get_dependencies();
return has_eltwise_dep(reorder_deps);
}
}
return false;
}

bool has_convolution_dep(const std::vector<std::pair<cldnn::program_node*, int>>& dependencies) {
for (auto& dependency : dependencies) {
// convolution node
if (dependency.first->is_type<convolution>()) {
auto& conv_deps = dependency.first->get_dependencies();
return has_reorder_dep(conv_deps);
}
}
return false;
}

// check dependencies for reorder node added for convolution in quantized model
bool skip_quantization_conv_reorder(const program_node& node) {
// reorder -> convolution -> reorder -> eltwise -> reorder -> reorder -> broadcast -> shape_of -> input_layout
if (!node.is_type<reorder>()) {
return false;
}

auto& dependencies = node.get_dependencies();
return has_convolution_dep(dependencies);
}

} // namespace

void mark_shape_of_subgraphs::look_for_shape_of_subgraph(program_node& node) {
if (node.is_type<shape_of>()) {
mark_node(node);
return;
}

// skip mark_node for reorder node (after convolution node) for quantized model
if (skip_quantization_conv_reorder(node)) {
return;
}
// skip mark_node for broadcast node if dependency nodes are data and shape_of
auto& dependencies = node.get_dependencies();
if (node.is_type<broadcast>() && dependencies.size() == 2) {
if (dependencies[0].first->is_type<data>() && dependencies[1].first->is_type<shape_of>())
return;

// Check if all dependencies are constant or marked as a part of shape_of subgraph
bool can_execute_in_subgraph = true;
bool has_shape_of_subgraph_dep = false;

for (auto& dependency : node.get_dependencies()) {
if (dependency.first->is_in_shape_of_subgraph()) {
has_shape_of_subgraph_dep = true;
Expand Down

0 comments on commit 457a67a

Please sign in to comment.