Skip to content

Commit

Permalink
[GPU] Fix count non zero for empty input
Browse files Browse the repository at this point in the history
  • Loading branch information
vladimir-paramuzov committed Mar 21, 2024
1 parent 5a0d71a commit b71cf4c
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/plugins/intel_gpu/src/graph/impls/ocl/non_zero.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ struct count_nonzero_impl : typed_primitive_impl_ocl<count_nonzero> {
}
}

event::ptr execute_impl(const std::vector<event::ptr>& events, count_nonzero_inst& instance) override {
if (instance.get_impl_params()->input_layouts[0].count() == 0) {
// set count of non-zero elements to 0 in case if input tensor is empty to have correct memory alloc for gather_nonzero
return instance.output_memory(0).fill(instance.get_network().get_stream(), 0);
// return parent::execute_impl(events, instance);
} else {
return parent::execute_impl(events, instance);
}
}

static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param, bool is_shape_agnostic = false) {
return get_default_params<kernel_selector::count_nonzero_params>(impl_param, is_shape_agnostic);
}
Expand Down
39 changes: 39 additions & 0 deletions src/plugins/intel_gpu/tests/unit/test_cases/non_zero_gpu_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -533,3 +533,42 @@ TEST(non_zero_gpu, const_input) {
ASSERT_FLOAT_EQ(output_ptr[i], out_data[i]);
}
}

TEST(non_zero_gpu, empty_input) {
auto& engine = get_test_engine();
auto in_layout = layout{ov::PartialShape{1, -1}, data_types::f32, format::bfyx};
auto in_data_layout = layout{ov::PartialShape{1, 0}, data_types::f32, format::bfyx};
auto input_data_mem = engine.allocate_memory(in_data_layout);

topology topology;
topology.add(input_layout("input", in_layout));
topology.add(count_nonzero("count_nonzero", input_info("input")));
topology.add(gather_nonzero("gather_nonzero", input_info("input"), input_info("count_nonzero")));

ExecutionConfig config = get_test_default_config(engine);
config.set_property(ov::intel_gpu::optimize_data(true));
config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
network net(engine, topology, config);

net.set_input_data("input", input_data_mem);

auto count_nonzero_inst = net.get_primitive("count_nonzero");

// Put some value into out buffer to ensure that it's non empty
// That is needed to ensure that implementation correctly handles the cases when input tensor is empty and set count non zero to 0
count_nonzero_inst->output_memory(0).fill(engine.get_service_stream(), 1);
engine.get_service_stream().finish();

auto count_nonzero_impl = count_nonzero_inst->get_impl();
ASSERT_TRUE(count_nonzero_impl != nullptr);

auto gather_nonzero_inst = net.get_primitive("gather_nonzero");
auto gather_nonzero_impl = gather_nonzero_inst->get_impl();
ASSERT_TRUE(gather_nonzero_impl != nullptr);
ASSERT_TRUE(gather_nonzero_impl->is_dynamic());

auto outputs = net.execute();

auto output = outputs.at("gather_nonzero").get_memory();
ASSERT_EQ(output, nullptr);
}

0 comments on commit b71cf4c

Please sign in to comment.