Skip to content

Commit

Permalink
[Ref] Concat op reference implementation code improvements (#23048)
Browse files Browse the repository at this point in the history
### Details:
- Concat op reference implementation code improvements, removal of code
duplication
- Leftovers from PR:
#22686
- Reduce Concat template code, pass element_type to distinguish copy
method for string
 (element::Type_t::undefined as default to keep compatibility)
- ~Corresponding update of places where reference::concat is used
(including one gpu file)~ (reverted)
 
(*First approach
(6308f9d)
was to simply introduce common template to reuse the Concat reference
code, but it still results in generation of string and char
specializations of the whole template function*
*Current approach is to use the same function and choose the copy method
inside, based on the element type*)

### Tickets:
 -Related to 131838
  • Loading branch information
mitruska authored Mar 4, 2024
1 parent caf9148 commit e2a7495
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 63 deletions.
11 changes: 3 additions & 8 deletions src/core/reference/include/openvino/reference/concat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <vector>

#include "openvino/core/shape.hpp"
#include "openvino/core/type/element_type.hpp"

namespace ov {
namespace reference {
Expand All @@ -15,14 +16,8 @@ void concat(const std::vector<const char*>& args,
const std::vector<Shape>& in_shapes,
const Shape& out_shape,
int64_t concatenation_axis,
size_t elem_size);

void concat(const std::vector<const std::string*>& args,
std::string* out,
const std::vector<Shape>& in_shapes,
const Shape& out_shape,
int64_t concatenation_axis,
size_t);
size_t elem_size,
const ov::element::Type& elem_type = ov::element::Type_t::undefined);

} // namespace reference
} // namespace ov
59 changes: 25 additions & 34 deletions src/core/reference/src/op/concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,60 +17,51 @@ std::vector<size_t> calculate_shape_sizes(const std::vector<Shape>& in_shapes) {
});
return sizes;
}

void copy_elements(const char* arg,
char* out,
size_t in_offset,
size_t out_offset,
size_t num_of_elements,
size_t elem_size) {
std::memcpy(out + (out_offset * elem_size), arg + (in_offset * elem_size), num_of_elements * elem_size);
}

void copy_string_elements(const char* arg,
char* out,
size_t in_offset,
size_t out_offset,
size_t num_of_elements,
size_t) {
const auto src_begin = std::next(reinterpret_cast<const std::string*>(arg), in_offset);
const auto out_ptr = std::next(reinterpret_cast<std::string*>(out), out_offset);
std::copy_n(src_begin, num_of_elements, out_ptr);
}
} // namespace

void concat(const std::vector<const char*>& args,
char* out,
const std::vector<Shape>& in_shapes,
const Shape& out_shape,
int64_t concatenation_axis,
size_t elem_size) {
size_t steps = 1;
for (int i = 0; i < concatenation_axis; ++i) {
steps *= out_shape[i];
}

size_t elem_size,
const ov::element::Type& elem_type) {
const auto steps = shape_size(out_shape.begin(), out_shape.begin() + concatenation_axis);
const auto& shape_sizes = calculate_shape_sizes(in_shapes);

size_t out_offset = 0;
for (size_t step = 0; step < steps; ++step) {
for (size_t in_index = 0; in_index < args.size(); ++in_index) {
const size_t size = shape_sizes[in_index] / steps;
const size_t in_offset = step * size;

std::memcpy(&out[out_offset * elem_size], &args[in_index][in_offset * elem_size], size * elem_size);

out_offset += size;
}
}
}

void concat(const std::vector<const std::string*>& args,
std::string* out,
const std::vector<Shape>& in_shapes,
const Shape& out_shape,
int64_t concatenation_axis,
size_t) {
size_t steps = 1;
for (int i = 0; i < concatenation_axis; ++i) {
steps *= out_shape[i];
}
const auto& shape_sizes = calculate_shape_sizes(in_shapes);
const auto copy_func = elem_type == ov::element::string ? copy_string_elements : copy_elements;

size_t out_offset = 0;
for (size_t step = 0; step < steps; ++step) {
for (size_t in_index = 0; in_index < args.size(); ++in_index) {
const size_t size = shape_sizes[in_index] / steps;
const size_t in_offset = step * size;

const auto src_begin = std::next(args[in_index], in_offset);
const auto out_ptr = std::next(out, out_offset);
std::copy_n(src_begin, size, out_ptr);
copy_func(args[in_index], out, in_offset, out_offset, size, elem_size);

out_offset += size;
}
}
}

} // namespace reference
} // namespace ov
34 changes: 13 additions & 21 deletions src/core/src/op/concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,43 +52,35 @@ std::shared_ptr<Node> Concat::clone_with_new_inputs(const OutputVector& new_args
return std::make_shared<Concat>(new_args, m_axis);
}

template <typename T>
void evaluate_concat(const Concat* node, TensorVector& outputs, const TensorVector& inputs) {
bool Concat::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
OV_OP_SCOPE(v0_Concat_evaluate);
OPENVINO_ASSERT(outputs.size() == 1);

const auto inputs_count = inputs.size();
std::vector<Shape> arg_shapes;
std::vector<PartialShape> input_shapes;
std::vector<const char*> arg_bufs;
arg_shapes.reserve(inputs_count);
input_shapes.reserve(inputs_count);
arg_bufs.reserve(inputs_count);

std::vector<const T*> arg_bufs(inputs_count);
auto arg_buf = arg_bufs.begin();
for (auto& input : inputs) {
*arg_buf = static_cast<const T*>(input.data());
++arg_buf;
const auto& input_shape = input.get_shape();
arg_shapes.emplace_back(input_shape);
input_shapes.emplace_back(input_shape);
arg_bufs.emplace_back(static_cast<const char*>(input.data()));
}

const auto& out_shape = shape_infer(node, input_shapes).front().to_shape();
const auto& out_shape = shape_infer(this, input_shapes).front().to_shape();
outputs.front().set_shape(out_shape);
const auto elem_type = outputs.front().get_element_type();
reference::concat(arg_bufs,
static_cast<T*>(outputs.front().data()),
static_cast<char*>(outputs.front().data()),
arg_shapes,
out_shape,
ov::util::normalize(node->get_axis(), out_shape.size()),
outputs.front().get_element_type().size());
}

bool Concat::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
OV_OP_SCOPE(v0_Concat_evaluate);
OPENVINO_ASSERT(outputs.size() == 1);

if (outputs.front().get_element_type() == ov::element::string) {
evaluate_concat<std::string>(this, outputs, inputs);
} else {
evaluate_concat<char>(this, outputs, inputs);
}
ov::util::normalize(this->get_axis(), out_shape.size()),
elem_type.size(),
elem_type);

return true;
}
Expand Down

0 comments on commit e2a7495

Please sign in to comment.