Skip to content

Commit

Permalink
[GPU] New shape infer baseline (openvinotoolkit#12241)
Browse files Browse the repository at this point in the history
  • Loading branch information
vladimir-paramuzov authored Aug 15, 2022
1 parent d04521a commit b33f22c
Show file tree
Hide file tree
Showing 40 changed files with 799 additions and 189 deletions.
1 change: 1 addition & 0 deletions src/core/include/openvino/op/util/gather_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class OPENVINO_API GatherBase : public Op {

bool constant_fold(OutputVector& output_values, const OutputVector& inputs_values) override;
const int64_t& get_batch_dims() const;
void set_batch_dims(int64_t batch_dims);

protected:
int64_t m_batch_dims = 0;
Expand Down
35 changes: 5 additions & 30 deletions src/core/shape_inference/include/shape_nodes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ void shape_infer(const ov::opset1::Reshape* op, const std::vector<T> &input_shap
output_shape.resize(output_pattern.size());

auto output_rank = input_shapes[1].size() == 0 ? 0 : input_shapes[1][0];
if (output_rank == 0 && !output_shape.empty()) {
if (output_rank == 0 && output_shape.size() != 0) {
output_pattern.clear();
OPENVINO_ASSERT(output_pattern.size() == 1);
NODE_VALIDATION_CHECK(op,
Expand Down Expand Up @@ -56,7 +56,7 @@ void shape_infer(const ov::opset1::Reshape* op, const std::vector<T> &input_shap
}
size_t input_product(1);
for (size_t i = 0; i < input_shape.size(); ++i) {
if (i < static_cast<int64_t>(output_pattern.size()) && output_pattern[i] == 0)
if (i < static_cast<size_t>(output_pattern.size()) && output_pattern[i] == 0)
continue;
input_product = input_shape[i].get_length() * input_product;
}
Expand Down Expand Up @@ -93,13 +93,6 @@ void shape_infer(const ov::opset1::Reshape* op, const std::vector<T> &input_shap
input_shape);
}


template<>
void shape_infer<ov::PartialShape>(const ov::opset1::Reshape* op, const std::vector<ov::PartialShape> &input_shapes, std::vector<ov::PartialShape> &output_shapes,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data) {
OPENVINO_UNREACHABLE("Reshape shape inference is not yet unified for use with PartialShapes");
}

template<class T>
void shape_infer(const ov::opset1::Squeeze* op, const std::vector<T> &input_shapes, std::vector<T> &output_shapes,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
Expand All @@ -111,7 +104,7 @@ void shape_infer(const ov::opset1::Squeeze* op, const std::vector<T> &input_shap
auto& input_shape = input_shapes[0];
OPENVINO_ASSERT(input_shape.is_static());
auto& output_shape = output_shapes[0];
output_shape.clear();
output_shape = T{};

ov::normalize_axes(op, input_shape.rank().get_length(), axes);

Expand All @@ -124,14 +117,6 @@ void shape_infer(const ov::opset1::Squeeze* op, const std::vector<T> &input_shap
}
}


template<>
void shape_infer<ov::PartialShape>(const ov::opset1::Squeeze* op,
const std::vector<ov::PartialShape> &input_shapes, std::vector<ov::PartialShape> &output_shapes,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data) {
OPENVINO_UNREACHABLE("Squeeze shape inference is not yet unified for use with PartialShapes");
}

template<class T>
void shape_infer(const ov::opset1::Unsqueeze* op,
const std::vector<T> &input_shapes, std::vector<T> &output_shapes,
Expand All @@ -148,25 +133,16 @@ void shape_infer(const ov::opset1::Unsqueeze* op,

NODE_VALIDATION_CHECK(op, !axes.empty(), "'axes' input is mandatory");

auto expanded_rank = input_shape.size() + axes.size();
int64_t expanded_rank = input_shape.size() + axes.size();
ov::normalize_axes(op, static_cast<int64_t>(expanded_rank), axes);

std::set<int64_t> unique_sorted_axes(axes.begin(), axes.end());
for (const auto& axis : unique_sorted_axes) {
NODE_VALIDATION_CHECK(op, axis <= expanded_rank, "provided 'axes' value ", axis, " is not valid.");
output_shape.insert(next(begin(output_shape), axis), 1);
output_shape.insert(next(output_shape.begin(), axis), 1);
}
}


template<>
void shape_infer<ov::PartialShape>(const ov::opset1::Unsqueeze* op,
const std::vector<ov::PartialShape> &input_shapes, std::vector<ov::PartialShape> &output_shapes,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data) {
OPENVINO_UNREACHABLE("Unsqueeze shape inference is not yet unified for use with PartialShapes");
}


template <class T>
inline void dynamic_shape(T& output_shape) {
OPENVINO_UNREACHABLE("This code should be executed only for PartialShape class");
Expand Down Expand Up @@ -209,4 +185,3 @@ void shape_infer(const ov::opset3::ShapeOf* op,
NODE_VALIDATION_CHECK(op, input_shapes.size() == 1 && output_shapes.size() == 1);
shape_of_shape_infer(input_shapes[0], output_shapes[0]);
}

4 changes: 4 additions & 0 deletions src/core/src/op/util/gather_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ const int64_t& ov::op::util::GatherBase::get_batch_dims() const {
return m_batch_dims;
}

void ov::op::util::GatherBase::set_batch_dims(int64_t batch_dims) {
m_batch_dims = batch_dims;
}

namespace gather {
namespace {
template <ov::element::Type_t ET>
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_gpu/include/intel_gpu/graph/program.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ struct program {
std::set<std::shared_ptr<program_node>> const& nodes,
build_options const& options,
bool is_internal);
explicit program(engine& engine);
~program();
engine& get_engine() const { return _engine; }
const build_options& get_options() const { return options; }
Expand Down
39 changes: 38 additions & 1 deletion src/plugins/intel_gpu/include/intel_gpu/primitives/reshape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,47 @@ struct reshape : public primitive_base<reshape> {
const tensor& output_shape,
const primitive_id& ext_prim_id = "",
const padding& output_padding = padding())
: primitive_base(id, {input}, ext_prim_id, output_padding), output_shape(output_shape) {}
: primitive_base(id, {input}, ext_prim_id, output_padding)
, output_shape(output_shape)
, output_pattern({})
, output_partial_shape({}) {}

/// @brief reshape with dynamic pattern
reshape(const primitive_id& id,
const primitive_id& input,
const primitive_id& pattern_id,
bool special_zero,
const ov::PartialShape& output_partial_shape,
const primitive_id& ext_prim_id = "",
const padding& output_padding = padding())
: primitive_base(id, {input, pattern_id}, ext_prim_id, output_padding)
, output_shape(tensor())
, special_zero(special_zero)
, output_pattern({})
, output_partial_shape(output_partial_shape) {}

/// @brief reshape with static pattern
reshape(const primitive_id& id,
const primitive_id& input,
bool special_zero,
const std::vector<int64_t>& output_pattern,
const ov::PartialShape& output_partial_shape,
const primitive_id& ext_prim_id = "",
const padding& output_padding = padding())
: primitive_base(id, {input}, ext_prim_id, output_padding)
, output_shape(tensor())
, special_zero(special_zero)
, output_pattern(output_pattern)
, output_partial_shape(output_partial_shape) {}

/// @brief Requested memory shape.
tensor output_shape;

bool special_zero = false;

std::vector<int64_t> output_pattern;

ov::PartialShape output_partial_shape;
};

/// @}
Expand Down
57 changes: 56 additions & 1 deletion src/plugins/intel_gpu/include/intel_gpu/runtime/layout.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <set>

#include <openvino/core/partial_shape.hpp>
#include <openvino/core/type/element_type.hpp>

namespace cldnn {
/// @addtogroup cpp_api C++ API
Expand Down Expand Up @@ -116,6 +117,8 @@ struct data_type_traits {

static std::string name(data_types data_type) {
switch (data_type) {
case data_types::bin:
return "bin";
case data_types::i8:
return "i8";
case data_types::u8:
Expand All @@ -130,7 +133,7 @@ struct data_type_traits {
return "f32";
default:
assert(0);
return std::string("invalid data type: " + std::to_string(static_cast<int>(data_type)));
return "unknown (" + std::to_string(typename std::underlying_type<data_types>::type(data_type)) + ")";
}
}

Expand Down Expand Up @@ -209,6 +212,55 @@ bool data_type_match(data_types data_type) {
return data_type == type_to_data_type<T>::value;
}

inline data_types data_type_to_element_type(ov::element::Type t) {
switch (t) {
case ov::element::Type_t::i16:
case ov::element::Type_t::u16:
case ov::element::Type_t::f32:
case ov::element::Type_t::f64:
return cldnn::data_types::f32;
case ov::element::Type_t::f16:
return cldnn::data_types::f16;
case ov::element::Type_t::u8:
return cldnn::data_types::u8;
case ov::element::Type_t::i8:
return cldnn::data_types::i8;
case ov::element::Type_t::i32:
case ov::element::Type_t::u32:
case ov::element::Type_t::u64:
return cldnn::data_types::i32;
case ov::element::Type_t::i64:
return cldnn::data_types::i64;
case ov::element::Type_t::boolean:
return cldnn::data_types::i8;
case ov::element::Type_t::u1:
return cldnn::data_types::bin;
default:
throw std::runtime_error("Can't convert " + t.get_type_name() + " element type");
}
}

inline ov::element::Type element_type_to_data_type(data_types t) {
switch (t) {
case cldnn::data_types::f32:
return ov::element::Type_t::f32;
case cldnn::data_types::f16:
return ov::element::Type_t::f16;
case cldnn::data_types::u8:
return ov::element::Type_t::u8;
case cldnn::data_types::i8:
return ov::element::Type_t::i8;
case cldnn::data_types::i32:
return ov::element::Type_t::i32;
case cldnn::data_types::i64:
return ov::element::Type_t::i64;
case cldnn::data_types::bin:
return ov::element::Type_t::u1;
default:
throw std::runtime_error("Can't convert " + data_type_traits::name(t) + " precision");
}
}

/// Helper function to get both data_types and format::type in a single, unique value. Useable in 'case' statement.
constexpr auto fuse(data_types dt, cldnn::format::type fmt) -> decltype(static_cast<std::underlying_type<data_types>::type>(dt) |
static_cast<std::underlying_type<format::type>::type>(fmt)) {
Expand Down Expand Up @@ -425,6 +477,9 @@ struct layout {

tensor get_tensor() const;

template<typename T>
T get() const;

void set_tensor(const tensor& size);

// Returns true if other layout can be reinterpreted without need of reordering
Expand Down
8 changes: 8 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/runtime/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include "event.hpp"
#include "engine_configuration.hpp"

#include "ngraph/runtime/host_tensor.hpp"

#ifdef ENABLE_ONEDNN_FOR_GPU
#include <oneapi/dnnl/dnnl.hpp>
#endif
Expand Down Expand Up @@ -204,4 +206,10 @@ inline std::vector<T> read_vector(cldnn::memory::ptr mem, const cldnn::stream& s
return out_vecs;
}

inline std::shared_ptr<ngraph::runtime::HostTensor> make_host_tensor(layout l, void* memory_pointer) {
ov::element::Type et = element_type_to_data_type(l.data_type);

return std::make_shared<ngraph::runtime::HostTensor>(et, l.get_shape(), memory_pointer);
}

} // namespace cldnn
1 change: 1 addition & 0 deletions src/plugins/intel_gpu/src/graph/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ target_compile_options(${TARGET_NAME} PRIVATE
target_link_libraries(${TARGET_NAME} PUBLIC OpenCL)
target_link_libraries(${TARGET_NAME} PRIVATE openvino_intel_gpu_kernels
openvino_intel_gpu_runtime
ov_shape_inference
openvino::itt
openvino::runtime::dev
openvino::runtime)
Expand Down
34 changes: 34 additions & 0 deletions src/plugins/intel_gpu/src/graph/gather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include "json_object.h"
#include <string>

#include "gather_shape_inference.hpp"

namespace cldnn {
primitive_type_id gather::type_id() {
static primitive_type_base<gather> instance;
Expand Down Expand Up @@ -62,6 +64,38 @@ layout gather_inst::calc_output_layout(gather_node const& node, kernel_impl_para
tensor(format::get_default_format(dims_converted.size()), dims_converted)};
}

template<typename ShapeType>
std::vector<layout> gather_inst::calc_output_layouts(gather_node const& node, const kernel_impl_params& impl_param) {
auto desc = node.get_primitive();

auto input0_layout = impl_param.get_input_layout(0);
auto input1_layout = impl_param.get_input_layout(1);

auto output_type = input0_layout.data_type;
if (node.has_fused_primitives()) {
output_type = node.get_fused_output_layout().data_type;
}

ov::op::v8::Gather op;
op.set_batch_dims(desc->batch_dim);
std::vector<ShapeType> output_shapes = {ShapeType()};
std::vector<ShapeType> input_shapes = {
input0_layout.get<ShapeType>(),
input1_layout.get<ShapeType>(),
ShapeType{1} // axis input is removed on gather primitive creation, so we can't use get_dependency(2)
};

int64_t axis = desc->axis;

auto axis_tensor = std::make_shared<ngraph::runtime::HostTensor>(ov::element::i64, ov::Shape{1}, static_cast<void*>(&axis));
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> const_data = {{2, axis_tensor}};
ov::op::util::shape_infer(&op, input_shapes, output_shapes, const_data);

format output_format = format::adjust_to_rank(input0_layout.format, output_shapes[0].size());

return { layout{output_shapes[0], output_type, output_format} };
}

std::string gather_inst::to_string(gather_node const& node) {
auto desc = node.get_primitive();
auto node_info = node.desc_to_json();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,20 +158,9 @@ struct convolution_onednn : typed_primitive_onednn_impl<convolution, dnnl::convo
cldnn::format out_fmt = onednn::find_format(pd.weights_desc(0), grouped_weights);
kernel_selector::WeightsLayout reqLayout = to_weights_layout(out_fmt, cldnn_prim->grouped_weights_shape);

const auto& param_info = kernel_impl_params(arg.get_program(), cldnn_prim, arg.get_unique_id(),
arg.get_input_layouts(), arg.get_output_layout(),
arg.get_fused_primitives(),
arg.get_fused_activations_funcs(), arg.get_fused_activations_params(),
optional_layout(weights_layout),
arg.bias_term() ? optional_layout(arg.bias().get_output_layout()) : optional_layout(),
arg.weights_zero_points_term() ? optional_layout(arg.weights_zero_points().get_output_layout())
: optional_layout(),
arg.activations_zero_points_term() ? optional_layout(arg.activations_zero_points().get_output_layout())
: optional_layout(),
arg.compensation_term() ? optional_layout(arg.compensation().get_output_layout())
: optional_layout());

set_params(param_info, r_params);
const auto& param_info = arg.get_kernel_impl_params();

set_params(*param_info, r_params);
r_params.layerID = arg.id() + "_reorder_";
r_params.input = convert_weights_tensor(weights_layout, cldnn_prim->grouped_weights_shape);
r_params.output = r_params.input.TransformIgnorePadding(reqLayout, r_params.input.GetDType(), arg.get_groups(), false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,9 @@ struct deconvolution_onednn : typed_primitive_onednn_impl<deconvolution, dnnl::d
cldnn::format out_fmt = onednn::find_format(pd.weights_desc(0), grouped_weights);
kernel_selector::WeightsLayout reqLayout = to_weights_layout(out_fmt, cldnn_prim->grouped_weights_shape);

const auto& param_info = kernel_impl_params(arg.get_program(), cldnn_prim, arg.get_unique_id(),
arg.get_input_layouts(), arg.get_output_layout(),
arg.get_fused_primitives(),
arg.get_fused_activations_funcs(), arg.get_fused_activations_params(),
optional_layout(weights_layout),
arg.bias_term() ? optional_layout(arg.bias().get_output_layout()) : optional_layout());

set_params(param_info, r_params);
const auto& param_info = arg.get_kernel_impl_params();

set_params(*param_info, r_params);
r_params.layerID = arg.id() + "_reorder_";
r_params.input = convert_weights_tensor(weights_layout, cldnn_prim->grouped_weights_shape);
r_params.output = r_params.input.TransformIgnorePadding(reqLayout, r_params.input.GetDType(), arg.get_groups(), false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,7 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected, dnn
static kernel_selector::WeightsReorderParams get_weights_reorder(const fully_connected_node& arg, const dnnl::primitive_desc& pd) {
auto weights_layout = arg.get_dependency(1).get_output_layout();
auto cldnn_prim = arg.get_primitive();
const auto& param_info = kernel_impl_params(arg.get_program(), cldnn_prim, arg.get_unique_id(),
arg.get_input_layouts(), arg.get_output_layout(),
arg.get_fused_primitives(),
arg.get_fused_activations_funcs(), arg.get_fused_activations_params(),
optional_layout(weights_layout),
arg.bias_term() ? optional_layout(arg.bias().get_output_layout()) : optional_layout());
auto param_info = arg.get_kernel_impl_params();

kernel_selector::WeightsReorderParams weights_reorder_params;
auto& reorderKS = kernel_selector::ReorderWeightsKernelSelctor::Instance();
Expand All @@ -75,7 +70,7 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected, dnn
kernel_selector::WeightsLayout req_layout = to_weights_layout(out_fmt, false);

// set engine info & forcing
set_params(param_info, r_params);
set_params(*param_info, r_params);
r_params.layerID = arg.id() + "_reorder_";
r_params.input = convert_weights_tensor(weights_layout, false);
r_params.output = r_params.input.TransformIgnorePadding(req_layout, r_params.input.GetDType(), 1, false);
Expand Down
Loading

0 comments on commit b33f22c

Please sign in to comment.