Skip to content

Commit

Permalink
partial_shape serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
e-ddykim committed Nov 15, 2022
1 parent 0b6465f commit 1ee8426
Showing 1 changed file with 31 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,33 @@
#include "intel_gpu/runtime/layout.hpp"

namespace cldnn {
template <typename BufferType>
class Serializer<BufferType, ov::PartialShape, typename std::enable_if<std::is_base_of<OutputBuffer<BufferType>, BufferType>::value>::type> {
public:
static void save(BufferType& buffer, const ov::PartialShape& partial_shape) {
std::vector<ov::Dimension> dimensions(partial_shape);
buffer << dimensions.size();
for (const auto& dimension : dimensions) {
buffer << dimension.get_interval().get_min_val();
buffer << dimension.get_interval().get_max_val();
}
}
};

template <typename BufferType>
class Serializer<BufferType, ov::PartialShape, typename std::enable_if<std::is_base_of<InputBuffer<BufferType>, BufferType>::value>::type> {
public:
static void load(BufferType& buffer, ov::PartialShape& partial_shape) {
size_t num_dimensions;
buffer >> num_dimensions;
for (size_t i = 0; i < num_dimensions; i++) {
ov::Dimension::value_type min_val, max_val;
buffer >> min_val >> max_val;
partial_shape.push_back(ov::Dimension(min_val, max_val));
}
}
};

template <typename BufferType>
class Serializer<BufferType, cldnn::layout, typename std::enable_if<std::is_base_of<OutputBuffer<BufferType>, BufferType>::value>::type> {
public:
Expand All @@ -21,15 +48,7 @@ class Serializer<BufferType, cldnn::layout, typename std::enable_if<std::is_base
buffer << _layout.data_padding.filling_value();
buffer << _layout.data_padding.lower_size().sizes();
buffer << _layout.data_padding.upper_size().sizes();

std::vector<cldnn::tensor::value_type> _sizes = _layout.get_tensor().sizes(_layout.format);
// Temp WA for bs_x_bsv16
if (_layout.format == cldnn::format::bs_x_bsv16) {
std::vector<cldnn::tensor::value_type> _tmp_sizes = _layout.get_tensor().sizes();
_sizes[0] = _tmp_sizes[0];
_sizes[1] = _tmp_sizes[1];
}
buffer << _sizes;
buffer << _layout.get_partial_shape();
}
};

Expand All @@ -50,15 +69,9 @@ class Serializer<BufferType, cldnn::layout, typename std::enable_if<std::is_base
_layout.data_padding = cldnn::padding(_lower_size, _upper_size, _filling_value);
}

std::vector<cldnn::tensor::value_type> _sizes;
buffer >> _sizes;

// Temp WA for bs_x_bsv16
if (_layout.format == cldnn::format::bs_x_bsv16) {
_layout.set_tensor(tensor(_sizes));
} else {
_layout.set_tensor(tensor(_layout.format, _sizes));
}
ov::PartialShape partial_shape;
buffer >> partial_shape;
_layout.set_partial_shape(partial_shape);
}
};

Expand Down

0 comments on commit 1ee8426

Please sign in to comment.