[GPU] Need to exclude fused mem_dep from shape_infer_dep
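
Fused primitives append their extra inputs to the tail of the owning node's dependency list (from each fused entry's dep_start_idx onward), so those indices feed the fused operation rather than the node's own shape inference. This change adds program_node::is_fused_dep() to detect such indices, lets get_const_memory_deps() optionally exclude them, and makes primitive_inst::update_shape() skip them when collecting dependency events.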
kelvinchoi-intel committed Mar 20, 2023
1 parent dcc8a36 commit 2f6fbd7
Showing 3 changed files with 27 additions and 3 deletions.
4 changes: 3 additions & 1 deletion src/plugins/intel_gpu/src/graph/include/program_node.h
@@ -99,7 +99,9 @@ struct program_node {
         return false;
     }
 
-    std::map<size_t, memory::ptr> get_const_memory_deps() const;
+    bool is_fused_dep(size_t dep_idx) const;
+
+    std::map<size_t, memory::ptr> get_const_memory_deps(bool exclude_fused_dep = false) const;
 
     virtual std::unique_ptr<kernel_impl_params> get_kernel_impl_params() const {
         return get_kernel_impl_params(get_input_layouts(), output_layouts);
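Since the new parameter defaults to false, existing call sites keep their current behavior; only callers that pass true opt in to the filtering. A hypothetical caller (sketch, not part of this commit):

    // Default: include every constant shape-infer memory dependency.
    auto all_deps = node.get_const_memory_deps();
    // Opt in: skip dependencies owned by fused primitives.
    auto non_fused_deps = node.get_const_memory_deps(true);
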
6 changes: 5 additions & 1 deletion src/plugins/intel_gpu/src/graph/primitive_inst.cpp
@@ -165,7 +165,7 @@ void primitive_inst::update_shape() {
         set_shape_change();
 
     // Even though the predecessors' shapes are not changed, the output shape might be updated by the mem_dep
-    auto memory_deps = _node->get_const_memory_deps();
+    auto memory_deps = _node->get_const_memory_deps(true);
     for (auto& i : _node->get_shape_infer_dependencies()) {
         if (memory_deps.count(i) > 0) {
             continue;
@@ -190,6 +190,10 @@
         }
         auto& dep = _node->get_dependency(i);
         auto dep_id = dep.id();
+        // skip fused dependencies: their inputs sit at the tail of the dependency list
+        if (_node->is_fused_dep(i)) {
+            break;
+        }
         // Events may not be created for an in-order queue, so take them for the OOO queue only
         if (_network.has_event(dep_id) && queue_type == QueueTypes::out_of_order) {
             dependencies_events.push_back(_network.get_primitive_event(dep_id));
20 changes: 19 additions & 1 deletion src/plugins/intel_gpu/src/graph/program_node.cpp
@@ -378,13 +378,31 @@ bool program_node::has_padded_dependency() const {
     });
 }
 
-std::map<size_t, memory::ptr> program_node::get_const_memory_deps() const {
+bool program_node::is_fused_dep(size_t dep_idx) const {
+    for (const auto& fused : get_fused_primitives()) {
+        if (dep_idx >= fused.dep_start_idx) {
+            return true;
+        }
+    }
+
+    return false;
+}
+
+std::map<size_t, memory::ptr> program_node::get_const_memory_deps(bool exclude_fused_dep) const {
     std::map<size_t, memory::ptr> mem_deps;
     for (auto& i : get_shape_infer_dependencies()) {
         // Some primitives may have a flexible number of deps (e.g. reshape), thus allow skipping some deps
         if (i >= get_dependencies().size())
             continue;
 
+        // optionally exclude dependencies owned by fused primitives
+        if (exclude_fused_dep) {
+            if (is_fused_dep(i)) {
+                continue;
+            }
+        }
+
         // constant type only
         auto& dep = get_dependency(i);
         if (dep.is_type<data>()) {
             mem_deps.insert({i, dep.as<data>().get_attached_memory_ptr()});
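
To illustrate the indexing convention is_fused_dep() relies on, here is a minimal standalone sketch. The types are toy stand-ins, not the OpenVINO ones; the only assumption carried over from the diff is that each fused primitive records the dependency index (dep_start_idx) at which its extra inputs were appended to the owning node's dependency list.

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Toy stand-in for the fused-primitive descriptor: only the field the check needs.
    struct fused_desc {
        size_t dep_start_idx;  // first dependency slot used by this fused primitive
    };

    // Mirrors program_node::is_fused_dep(): any index at or beyond some fused
    // primitive's dep_start_idx is owned by that fused primitive.
    bool is_fused_dep(const std::vector<fused_desc>& fused, size_t dep_idx) {
        for (const auto& f : fused) {
            if (dep_idx >= f.dep_start_idx)
                return true;
        }
        return false;
    }

    int main() {
        // A node with two regular inputs (indices 0 and 1); a fused eltwise
        // appended its extra input at index 2.
        std::vector<fused_desc> fused = {{2}};
        for (size_t i = 0; i < 3; ++i)
            std::cout << "dep " << i << ": "
                      << (is_fused_dep(fused, i) ? "fused" : "regular") << "\n";
        // Prints: dep 0: regular, dep 1: regular, dep 2: fused
    }

Because fused inputs occupy the tail of the dependency list, get_const_memory_deps(true) simply leaves those indices out of the returned map, and the break in update_shape() stops scanning once the first fused index is reached.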
