Skip to content

Commit

Permalink
Fix reviewed code and remove i8 and u8 data type.
Browse files Browse the repository at this point in the history
  • Loading branch information
yunji-yunji committed Jul 26, 2021
1 parent 306c093 commit b36cb1a
Show file tree
Hide file tree
Showing 14 changed files with 26 additions and 54 deletions.
4 changes: 2 additions & 2 deletions inference-engine/src/cldnn_engine/ops/gather_elements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

namespace CLDNNPlugin {

static cldnn::gather_elements::gather_elements_axis GetGatherElementsAxis(int axis, unsigned rank) {
static cldnn::gather_elements::gather_elements_axis GetGatherAxis(int axis, unsigned rank) {
if (axis < 0)
axis += rank;
if (axis < 0 || axis >= rank)
Expand Down Expand Up @@ -55,7 +55,7 @@ void CreateGatherElementsOp(Program& p, const std::shared_ptr<ngraph::op::v6::Ga
inputPrimitives[1],
outLayout,
CldnnTensorFromIEDims(op->get_output_shape(0)),
GetGatherElementsAxis(axis, rank));
GetGatherAxis(axis, rank));

p.AddPrimitive(primitive);
p.AddPrimitiveToProfiler(op);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ const std::vector<InferenceEngine::Precision> dPrecisions = {
InferenceEngine::Precision::I32,
InferenceEngine::Precision::I64,
InferenceEngine::Precision::I16,
InferenceEngine::Precision::U8,
InferenceEngine::Precision::I8
};
const std::vector<InferenceEngine::Precision> iPrecisions = {
InferenceEngine::Precision::I32,
Expand Down Expand Up @@ -74,4 +72,4 @@ INSTANTIATE_TEST_SUITE_P(smoke_set5, GatherElementsLayerTest,
::testing::ValuesIn(iPrecisions),
::testing::Values(CommonTestUtils::DEVICE_CPU)),
GatherElementsLayerTest::getTestCaseName);
} // namespace
} // namespace
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
// //
//

#include <vector>
#include <ngraph/opsets/opset6.hpp>
Expand All @@ -17,6 +17,8 @@ const std::vector<InferenceEngine::Precision> inputPrecisions = {
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16,
InferenceEngine::Precision::I32,
InferenceEngine::Precision::I8,
InferenceEngine::Precision::U8,
};

const std::vector<InferenceEngine::Precision> idxPrecisions = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,4 @@ struct gather_elements : public primitive_base<gather_elements> {
/// @}
/// @}
/// @}
} // namespace cldnn
} // namespace cldnn
Original file line number Diff line number Diff line change
Expand Up @@ -514,18 +514,6 @@ enum class GatherAxis {
BATCH,
};

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// GatherElementsAxis
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
enum class GatherElementsAxis {
X,
Y,
Z,
W,
FEATURE,
BATCH,
};

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// ScatterUpdateAxis
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@ static size_t GetGatherElementsChannelIndex(const gather_elements_params& params
size_t inputSize = params.inputs[0].GetDims().size();

switch (params.axis) {
case GatherElementsAxis::X:
case GatherAxis::X:
return inputSize - 1;
case GatherElementsAxis::Y:
case GatherAxis::Y:
return inputSize - 2;
case GatherElementsAxis::Z:
case GatherAxis::Z:
return inputSize - 3;
case GatherElementsAxis::W:
case GatherAxis::W:
return 2;
case GatherElementsAxis::FEATURE:
case GatherAxis::FEATURE:
return 1;
case GatherElementsAxis::BATCH:
case GatherAxis::BATCH:
return 0;
default:
break;
Expand All @@ -38,13 +38,9 @@ ParamsKey GatherElementsKernelRef::GetSupportedKey() const {
k.EnableInputDataType(Datatype::F16);
k.EnableInputDataType(Datatype::F32);
k.EnableInputDataType(Datatype::INT32);
k.EnableInputDataType(Datatype::INT8);
k.EnableInputDataType(Datatype::UINT8);
k.EnableOutputDataType(Datatype::F16);
k.EnableOutputDataType(Datatype::F32);
k.EnableOutputDataType(Datatype::INT32);
k.EnableOutputDataType(Datatype::INT8);
k.EnableOutputDataType(Datatype::UINT8);
k.EnableInputLayout(DataLayout::bfyx);
k.EnableOutputLayout(DataLayout::bfyx);
k.EnableInputLayout(DataLayout::bfzyx);
Expand Down Expand Up @@ -152,4 +148,7 @@ KernelsData GatherElementsKernelRef::GetKernelsData(const Params& params, const
return { kd };
}

KernelsPriority GatherElementsKernelRef::GetKernelsPriority(const Params& /*params*/, const optional_params& /*options*/) const {
return DONT_USE_IF_HAVE_SOMETHING_ELSE;
}
} // namespace kernel_selector
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ namespace kernel_selector {
// gather_elements_params
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
struct gather_elements_params : public base_params {
gather_elements_params() : base_params(KernelType::GATHER_ELEMENTS), axis(GatherElementsAxis::BATCH) {}
gather_elements_params() : base_params(KernelType::GATHER_ELEMENTS), axis(GatherAxis::BATCH) {}

GatherElementsAxis axis;
GatherAxis axis;
};

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
Expand All @@ -30,6 +30,7 @@ class GatherElementsKernelRef : public KernelBaseOpenCL {
virtual JitConstants GetJitConstants(const gather_elements_params& params) const;
virtual CommonDispatchData SetDefault(const gather_elements_params& params, const optional_params&) const;
KernelsData GetKernelsData(const Params& params, const optional_params& options) const override;
KernelsPriority GetKernelsPriority(const Params& params, const optional_params& options) const override;
ParamsKey GetSupportedKey() const override;
std::vector<FusedOpType> GetSupportedFusedOps() const override {
return { FusedOpType::QUANTIZE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@ gather_elements_kernel_selector::gather_elements_kernel_selector() { Attach<Gath
KernelsData gather_elements_kernel_selector::GetBestKernels(const Params& params, const optional_params& options) const {
return GetNaiveBestKernel(params, options, KernelType::GATHER_ELEMENTS);
}
} // namespace kernel_selector
} // namespace kernel_selector
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,4 @@ class gather_elements_kernel_selector : public kernel_selector_base {

KernelsData GetBestKernels(const Params& params, const optional_params& options) const override;
};
} // namespace kernel_selector
} // namespace kernel_selector
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,5 @@ KERNEL(gather_elements_ref)(const __global INPUT0_TYPE* data,
#endif
}

#undef ORDER
#undef GET_OUTPUT_INDEX
Original file line number Diff line number Diff line change
Expand Up @@ -402,24 +402,14 @@ std::string toString(GatherAxis a) {
switch (a) {
case GatherAxis::X: return "X";
case GatherAxis::Y: return "Y";
case GatherAxis::Z: return "Z";
case GatherAxis::W: return "W";
case GatherAxis::FEATURE: return "FEATURE";
case GatherAxis::BATCH: return "BATCH";
default: return "";
}
}

std::string toString(GatherElementsAxis a) {
switch (a) {
case GatherElementsAxis::X: return "X";
case GatherElementsAxis::Y: return "Y";
case GatherElementsAxis::Z: return "Z";
case GatherElementsAxis::W: return "W";
case GatherElementsAxis::FEATURE: return "FEATURE";
case GatherElementsAxis::BATCH: return "BATCH";
default: return "";
}
}

std::string toString(ScatterUpdateAxis a) {
switch (a) {
case ScatterUpdateAxis::X: return "X";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ std::string toString(MVNEpsMode mode);
std::string toString(WeightsLayout layout);
std::string toString(ConcatAxis a);
std::string toString(GatherAxis a);
std::string toString(GatherElementsAxis a);
std::string toString(ScatterUpdateAxis a);
std::string toString(ResampleType type);
std::string toString(CoordinateTransformationMode mode);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ kernel_selector::gather_elements_axis convert_axis(gather_elements::gather_eleme
case gather_elements::along_b:
return kernel_selector::gather_elements_axis::BATCH;
default:
return kernel_selector::gather_elements_axis::X;
return kernel_selector::gather_elements_axis::BATCH;
}
}

Expand Down Expand Up @@ -72,18 +72,12 @@ attach_gather_elements_impl::attach_gather_elements_impl() {
std::make_tuple(data_types::f32, format::bfyx),
std::make_tuple(data_types::f16, format::bfyx),
std::make_tuple(data_types::i32, format::bfyx),
std::make_tuple(data_types::i8, format::bfyx),
std::make_tuple(data_types::u8, format::bfyx),
std::make_tuple(data_types::f32, format::bfzyx),
std::make_tuple(data_types::f16, format::bfzyx),
std::make_tuple(data_types::i32, format::bfzyx),
std::make_tuple(data_types::i8, format::bfzyx),
std::make_tuple(data_types::u8, format::bfzyx),
std::make_tuple(data_types::f32, format::bfwzyx),
std::make_tuple(data_types::f16, format::bfwzyx),
std::make_tuple(data_types::i32, format::bfwzyx),
std::make_tuple(data_types::i8, format::bfwzyx),
std::make_tuple(data_types::u8, format::bfwzyx),
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ using shape_calculation_mode = kernel_selector::ShapeCalculationMode;
using interpolate_axis = kernel_selector::InterpolateAxis;
using border_type = kernel_selector::BorderType;
using gather_axis = kernel_selector::GatherAxis;
using gather_elements_axis = kernel_selector::GatherElementsAxis;
using gather_elements_axis = kernel_selector::GatherAxis;
using scatter_update_axis = kernel_selector::ScatterUpdateAxis;
using reduce_mode = kernel_selector::ReduceMode;
using cum_sum_axis = kernel_selector::CumSumAxis;
Expand Down

0 comments on commit b36cb1a

Please sign in to comment.