Skip to content

Commit

Permalink
[GPU] Save use_onednn attribute in the blob
Browse files Browse the repository at this point in the history
  • Loading branch information
vladimir-paramuzov committed Oct 17, 2024
1 parent 324a282 commit 9dbb8d8
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 13 deletions.
4 changes: 4 additions & 0 deletions src/plugins/intel_gpu/src/graph/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1776,6 +1776,7 @@ void program::save(cldnn::BinaryOutputBuffer& ob) const {

ob << _is_body_program;
ob << _can_be_optimized;
ob << get_layout_optimizer().get_optimization_attributes().use_onednn_impls;
processing_order.save(ob);

{
Expand Down Expand Up @@ -1895,6 +1896,9 @@ void program::load(cldnn::BinaryInputBuffer& ib) {

ib >> _is_body_program;
ib >> _can_be_optimized;
int32_t use_onednn_attr = 0;
ib >> use_onednn_attr;
get_layout_optimizer().set_optimization_attribute(layout_optimizer::optimization_attributes_type::use_onednn_impls, use_onednn_attr);
_loaded_from_cache = true;

processing_order.load(ib, *this);
Expand Down
51 changes: 38 additions & 13 deletions src/plugins/intel_gpu/tests/unit/test_cases/gemm_gpu_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "intel_gpu/runtime/compilation_context.hpp"
#include "gemm_inst.h"
#include "permute_inst.h"
#include "layout_optimizer.h"

#include <cstddef>
#include <vector>
Expand Down Expand Up @@ -625,7 +626,7 @@ class gemm_gpu_tests: public ::testing::Test {
topology topology;
topology.add(input_layout("input1", in1_layout),
input_layout("input2", in2_layout),
gemm("gemm_ref", { input_info("input1"), input_info("input2") }, data_types::f16,
gemm("gemm_ref", { input_info("input1"), input_info("input2") }, data_types::f16,
{0, 2, 1, 3}, {0, 2, 3, 1}, {0, 1, 2, 3})
);

Expand All @@ -652,7 +653,7 @@ class gemm_gpu_tests: public ::testing::Test {
topology topology;
topology.add(input_layout("input1", in1_layout),
input_layout("input2", in2_layout),
gemm("gemm", { input_info("input1"), input_info("input2") }, data_types::f16,
gemm("gemm", { input_info("input1"), input_info("input2") }, data_types::f16,
{0, 2, 1, 3}, {0, 2, 3, 1}, {0, 1, 2, 3})
);

Expand Down Expand Up @@ -2789,7 +2790,7 @@ INSTANTIATE_TEST_SUITE_P(gemm_gpu, gemm_onednn_ndims, ::testing::ValuesIn(std::v

class gemm_onednn: public ::testing::Test {
public:
void test_impl_replacement_with_cldnn() {
void test_impl_replacement_with_cldnn(bool is_caching_test) {
auto& engine = get_test_engine();

if (!engine.get_device_info().supports_immad)
Expand Down Expand Up @@ -2828,16 +2829,34 @@ class gemm_onednn: public ::testing::Test {
ov::intel_gpu::optimize_data(true),
ov::intel_gpu::allow_new_shape_infer(true) };

network network(engine, topology, cfg);
network.set_input_data("input1", input1);
network.set_input_data("input2", input2);
cldnn::network::ptr network;
if (is_caching_test) {
membuf mem_buf;
{
std::ostream out_mem(&mem_buf);
BinaryOutputBuffer ob = BinaryOutputBuffer(out_mem);
ob.set_stream(get_test_stream_ptr().get());
program::build_program(engine, topology, cfg)->save(ob);
}
{
std::istream in_mem(&mem_buf);
BinaryInputBuffer ib = BinaryInputBuffer(in_mem, engine);
auto imported_prog = std::make_shared<cldnn::program>(engine, cfg);
imported_prog->load(ib);
network = std::make_shared<cldnn::network>(imported_prog);
}
} else {
network = std::make_shared<cldnn::network>(engine, topology, cfg);
}
network->set_input_data("input1", input1);
network->set_input_data("input2", input2);

auto inst = network.get_primitive("gemm");
auto inst = network->get_primitive("gemm");
auto impl = inst->get_impl();
ASSERT_TRUE(impl != nullptr);
ASSERT_TRUE(impl->is_dynamic());

auto outputs = network.execute();
auto outputs = network->execute();

auto output = outputs.at("gemm").get_memory();
cldnn::mem_lock<ov::float16> output_ptr(output, get_test_stream());
Expand All @@ -2847,12 +2866,15 @@ class gemm_onednn: public ::testing::Test {
ASSERT_FLOAT_EQ(output_ptr[i], out_data[i]);
}

// WA: Call wait_all() to wait for all queued kernels compilation finish
network.get_program()->get_compilation_context().wait_all();
// Call wait_all() to wait for all queued kernels compilation finish
network->get_program()->get_compilation_context().wait_all();

auto& lo = network->get_program()->get_layout_optimizer();
ASSERT_TRUE(lo.get_optimization_attributes().use_onednn_impls);

// Check if OneDNN's impl is used for the next execute() call
network.execute();
inst = network.get_primitive("gemm");
network->execute();
inst = network->get_primitive("gemm");
impl = inst->get_impl();
ASSERT_TRUE(impl != nullptr);
ASSERT_FALSE(impl->is_dynamic());
Expand Down Expand Up @@ -3214,7 +3236,10 @@ class gemm_onednn: public ::testing::Test {
};

TEST_F(gemm_onednn, impl_replacement_with_cldnn) {
this->test_impl_replacement_with_cldnn();
this->test_impl_replacement_with_cldnn(false);
}
TEST_F(gemm_onednn, impl_replacement_with_cldnn_cached) {
this->test_impl_replacement_with_cldnn(true);
}

// Check gemm_onednn transpose_format() can accept transpose white list format (byfx/bxfy)
Expand Down

0 comments on commit 9dbb8d8

Please sign in to comment.