Skip to content

Commit

Permalink
Fix reviewed codes.
Browse files Browse the repository at this point in the history
  • Loading branch information
yunji-yunji committed Jul 22, 2021
1 parent 6d36bb5 commit 2aa05b2
Show file tree
Hide file tree
Showing 8 changed files with 24 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,6 @@ REGISTER_FACTORY(v6, GatherElements);

// ------------------------------ Supported v7 ops ------------------------------ //
REGISTER_FACTORY(v7, Gather);
// REGISTER_FACTORY(v7, GatherElements);

// ------------------------------ Supported v8 ops ------------------------------ //
REGISTER_FACTORY(v8, Gather);
Expand Down
14 changes: 7 additions & 7 deletions inference-engine/src/cldnn_engine/ops/gather_elements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ static cldnn::gather_elements::gather_elements_axis GetGatherElementsAxis(int ax
case 3: return cldnn::gather_elements::gather_elements_axis::along_y;
case 4: return cldnn::gather_elements::gather_elements_axis::along_z;
case 5: return cldnn::gather_elements::gather_elements_axis::along_w;
default: IE_THROW() << "Unsupported ScatterElementsUpdate axis: " << axis;
default: IE_THROW() << "Unsupported GatherElements axis: " << axis;
}
return cldnn::gather_elements::gather_elements_axis::along_f; // shouldn't get here
}
Expand All @@ -51,16 +51,16 @@ void CreateGatherElementsOp(Program& p, const std::shared_ptr<ngraph::op::v6::Ga
auto outLayout = DefaultFormatForDims(op->get_output_shape(0).size());

auto primitive = cldnn::gather_elements(layerName,
inputPrimitives[0],
inputPrimitives[1],
outLayout,
CldnnTensorFromIEDims(op->get_output_shape(0)),
GetGatherElementsAxis(axis, rank));
inputPrimitives[0],
inputPrimitives[1],
outLayout,
CldnnTensorFromIEDims(op->get_output_shape(0)),
GetGatherElementsAxis(axis, rank));

p.AddPrimitive(primitive);
p.AddPrimitiveToProfiler(op);
}

REGISTER_FACTORY_IMPL(v6, GatherElements);

} // namespace CLDNNPlugin
} // namespace CLDNNPlugin
Original file line number Diff line number Diff line change
@@ -1,18 +1,6 @@
/*
// Copyright (c) 2021 Intel Corporation
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/

///////////////////////////////////////////////////////////////////////////////////////////////////
#pragma once
Expand Down Expand Up @@ -52,7 +40,6 @@ struct gather_elements : public primitive_base<gather_elements> {
const primitive_id& indices,
const format& output_format,
const tensor& output_shape,
// const uint8_t axis = 0,
const gather_elements_axis axis,
const padding& output_padding = padding())
: primitive_base(id, {data, indices}, output_padding), output_format(output_format), output_shape(output_shape), axis(axis) {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ ParamsKey GatherElementsKernelRef::GetSupportedKey() const {
k.EnableInputDataType(Datatype::F16);
k.EnableInputDataType(Datatype::F32);
k.EnableInputDataType(Datatype::INT32);
k.EnableInputDataType(Datatype::INT8);
k.EnableInputDataType(Datatype::UINT8);
k.EnableOutputDataType(Datatype::F16);
k.EnableOutputDataType(Datatype::F32);
k.EnableOutputDataType(Datatype::INT32);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
#include "include/data_types.cl"
#include "include/fetch_data.cl"

#define GET_UPDATES_INDEX(prefix, idx_order) CAT(prefix, _GET_INDEX)(idx_order)
#define GET_OUTPUT_INDEX(prefix, idx_order) CAT(prefix, _GET_INDEX)(idx_order)

KERNEL(gather_nd_ref)(const __global INPUT0_TYPE* data,
KERNEL(gather_elements_ref)(const __global INPUT0_TYPE* data,
const __global INPUT1_TYPE* indices,
__global OUTPUT_TYPE* output
#if HAS_FUSED_OPS_DECLS
Expand Down Expand Up @@ -39,7 +39,7 @@ KERNEL(gather_nd_ref)(const __global INPUT0_TYPE* data,
const uint f = dim2 % OUTPUT_FEATURE_NUM;
const uint b = dim2 / OUTPUT_FEATURE_NUM;

const int out_idx = GET_UPDATES_INDEX(INPUT1, ORDER);
const int out_idx = GET_OUTPUT_INDEX(INPUT1, ORDER);

#if INPUT1_DIMS == 4
size_t data_shape[4] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_Y, INPUT0_SIZE_X};
Expand Down Expand Up @@ -70,10 +70,7 @@ KERNEL(gather_nd_ref)(const __global INPUT0_TYPE* data,

size_t outer_sum = (out_idx / outer_sum_inc_indices) * outer_sum_inc_data;
size_t inner_sum = out_idx % max_inner_sum;
if (indices[out_idx] < 0 || indices[out_idx] >= data_shape[AXIS]) {
printf("indices values of GatherElement exceed data size.\n");
return;
}

uint idx = outer_sum + max_inner_sum * indices[out_idx] + inner_sum;
INPUT0_TYPE val = data[idx];

Expand All @@ -85,4 +82,4 @@ KERNEL(gather_nd_ref)(const __global INPUT0_TYPE* data,
#endif
}

#undef GET_UPDATES_INDEX
#undef GET_OUTPUT_INDEX
19 changes: 2 additions & 17 deletions inference-engine/thirdparty/clDNN/src/gather_elements.cpp
Original file line number Diff line number Diff line change
@@ -1,18 +1,6 @@
/*
// Copyright (c) 2021 Intel Corporation
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/

#include "gather_elements_inst.h"

Expand Down Expand Up @@ -40,13 +28,10 @@ layout gather_elements_inst::calc_output_layout(gather_elements_node const& node
input_layout_origin.data_type = node.get_fused_output_layout().data_type;
}

// const size_t input_dims = input_layout.size();
auto output_type = indices_layout_origin.data_type;
auto output_format = op->output_format;
auto output_shape = op->output_shape;

// const auto axis = op->axis;

// calculate initial output shape
return layout(output_type, output_format, output_shape);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,18 @@ attach_gather_elements_impl::attach_gather_elements_impl() {
std::make_tuple(data_types::f32, format::bfyx),
std::make_tuple(data_types::f16, format::bfyx),
std::make_tuple(data_types::i32, format::bfyx),
std::make_tuple(data_types::i8, format::bfyx),
std::make_tuple(data_types::u8, format::bfyx),
std::make_tuple(data_types::f32, format::bfzyx),
std::make_tuple(data_types::f16, format::bfzyx),
std::make_tuple(data_types::i32, format::bfzyx),
std::make_tuple(data_types::i8, format::bfzyx),
std::make_tuple(data_types::u8, format::bfzyx),
std::make_tuple(data_types::f32, format::bfwzyx),
std::make_tuple(data_types::f16, format::bfwzyx),
std::make_tuple(data_types::i32, format::bfwzyx),
std::make_tuple(data_types::i8, format::bfwzyx),
std::make_tuple(data_types::u8, format::bfwzyx),
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ inline void DoTest(engine& engine,
network.set_input_data("InputIndices", input1);
auto outputs = network.execute();
auto output = outputs.at("gather_elements").get_memory();
// auto output_ptr = output.pointer<uint16_t>();
cldnn::mem_lock<uint16_t> output_ptr(output, get_test_stream());

for (size_t i = 0; i < expected_results.size(); ++i) {
Expand Down

0 comments on commit 2aa05b2

Please sign in to comment.