Skip to content

Commit

Permalink
review fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
vladimir-paramuzov committed Sep 17, 2024
1 parent c6d157b commit efa1a06
Show file tree
Hide file tree
Showing 8 changed files with 13 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ void compile_graph::run(program& p) {
try {
const auto& params = node->get_kernel_impl_params();
auto shape_type = ImplementationManager::get_shape_type(*params);
auto selected_impl_manager = node->type()->choose_impl(*node, *node->get_kernel_impl_params(), shape_type);
auto selected_impl_manager = node->type()->choose_impl(*node, shape_type);
std::string fail_reason = "";
try {
if (selected_impl_manager) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ void select_preferred_formats::run(program& p) {
auto factory = test_format<std::shared_ptr<ImplementationManager>>(*n, format::any,
[&shape_type](program_node& n) {
return test_no_input_pad<std::shared_ptr<ImplementationManager>>(n, [&shape_type](program_node& n) {
return n.type()->choose_impl(n, *n.get_kernel_impl_params(), shape_type);
return n.type()->choose_impl(n, shape_type);
});
});

Expand Down
4 changes: 1 addition & 3 deletions src/plugins/intel_gpu/src/graph/include/primitive_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@ struct primitive_type {
const program_node& node) const = 0;

virtual std::unique_ptr<primitive_impl> create_impl(const program_node& node) const = 0;
virtual std::shared_ptr<ImplementationManager> choose_impl(const program_node& node,
const kernel_impl_params& params,
shape_types shape_type) const = 0;
virtual std::shared_ptr<ImplementationManager> choose_impl(const program_node& node, shape_types shape_type) const = 0;

virtual std::set<impl_types> get_available_impl_types(const program_node& node) const = 0;
virtual std::vector<std::shared_ptr<ImplementationManager>> get_supported_implementations(const program_node& node) const = 0;
Expand Down
6 changes: 2 additions & 4 deletions src/plugins/intel_gpu/src/graph/include/primitive_type_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ struct primitive_type_base : primitive_type {
return {};
}

std::shared_ptr<ImplementationManager> choose_impl(const program_node& node,
const kernel_impl_params& runtime_params,
shape_types requested_shape_type) const override {
std::shared_ptr<ImplementationManager> choose_impl(const program_node& node, shape_types requested_shape_type) const override {
OPENVINO_ASSERT(node.type() == this, "[GPU] primitive_type_base::choose_impl: primitive type mismatch");
for (auto& impl : get_supported_implementations(node)) {
impl_types impl_type = impl->get_impl_type();
Expand All @@ -68,7 +66,7 @@ struct primitive_type_base : primitive_type {
std::unique_ptr<primitive_impl> create_impl(const program_node& node) const override {
OPENVINO_ASSERT(node.type() == this, "[GPU] primitive_type_base::create_impl: primitive type mismatch");
const auto params = node.get_kernel_impl_params();
auto impl = choose_impl(node, *params, ImplementationManager::get_shape_type(*params));
auto impl = choose_impl(node, ImplementationManager::get_shape_type(*params));

const auto& p = node.get_primitive();
OPENVINO_ASSERT(impl != nullptr, "[GPU] Can't choose implementation for ", node.id(), " node (type=", p->type_string(), ")\n",
Expand Down
5 changes: 3 additions & 2 deletions src/plugins/intel_gpu/src/graph/include/program_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -602,11 +602,12 @@ inline RT test_format(program_node& node, format fmt, std::function<RT(program_n
node.recalc_output_layouts(false);

bool has_deps = !node.get_dependencies().empty();
layout prev_input_layout = has_deps ? node.get_input_layout(0) : layout();
layout prev_input_layout = layout();
if (has_deps) {
auto dep_with_port = node.get_dependency_with_port(0);
prev_input_layout = dep_with_port.first->get_output_layout(false, dep_with_port.second);
auto new_layout = prev_input_layout;
set_format_no_any(new_layout, fmt);
auto dep_with_port = node.get_dependency_with_port(0);
dep_with_port.first->set_output_layout(new_layout, false, dep_with_port.second);
}

Expand Down
7 changes: 3 additions & 4 deletions src/plugins/intel_gpu/src/graph/layout_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1199,13 +1199,12 @@ impl_types layout_optimizer::get_preferred_impl_type(program_node& node, format
if (forced_impl != impl_types::any)
return forced_impl;

const auto params = node.get_kernel_impl_params();
auto shape_type = shape_types::any;

auto impl = test_format<std::shared_ptr<ImplementationManager>>(node, preferred_format,
[&shape_type, &params](program_node& n) {
return test_no_input_pad<std::shared_ptr<ImplementationManager>>(n, [&shape_type, &params](program_node& n) {
return n.type()->choose_impl(n, *params, shape_type);
[&shape_type](program_node& n) {
return test_no_input_pad<std::shared_ptr<ImplementationManager>>(n, [&shape_type](program_node& n) {
return n.type()->choose_impl(n, shape_type);
});
});

Expand Down
2 changes: 1 addition & 1 deletion src/plugins/intel_gpu/src/graph/primitive_inst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2510,7 +2510,7 @@ std::shared_ptr<primitive_impl> ImplementationsFactory::get_primitive_impl_for_p
auto kernels = _program.get_kernels_cache().compile(updated_params, impl->get_kernels_source());
impl->set_kernels(kernels);
}
cache.add(updated_params, impl->clone());
cache.add(updated_params, std::move(impl));
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ TEST(kernel_impl_params_relevance, weights_layout) {
network.set_input_data("input", actual_input_data);

// 2. Force reference `fully_connected_gpu_bfyx_ref` kernel impl before execution,
// so during _node->type()->choose_impl(*_node, updated_params); call for static kernel version reference
// so during _node->type()->choose_impl(*_node); call for static kernel version reference
// impl will be used. Call execute() to trigger desired kernel compilation
auto fc_ref_impl = ov::intel_gpu::ImplementationDesc(format::bfyx, "fully_connected_gpu_bfyx_ref", impl_types::ocl);
auto force_impl_prop = ov::intel_gpu::force_implementations(ov::intel_gpu::ImplForcingMap{ {"fc", fc_ref_impl} });
Expand Down

0 comments on commit efa1a06

Please sign in to comment.