Skip to content

Commit

Permalink
Refactor StaticShape conversion to IShapeInfer::Result
Browse files Browse the repository at this point in the history
Signed-off-by: Raasz, Pawel <[email protected]>
  • Loading branch information
praasz committed Dec 4, 2024
1 parent 3f28bef commit 78a5ab0
Showing 1 changed file with 10 additions and 14 deletions.
24 changes: 10 additions & 14 deletions src/plugins/intel_cpu/src/shape_inference/shape_inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,20 +165,7 @@ class ShapeInferBase : public IStaticShapeInfer {

// call shape inference API
auto shape_infer_result = infer(input_static_shapes, MemoryAccessor(data_dependency, input_ranks));

Result result{{}, shape_infer_result ? ShapeInferStatus::success : ShapeInferStatus::skip};

if (shape_infer_result) {
result.dims.reserve(shape_infer_result->size());
std::transform(shape_infer_result->begin(),
shape_infer_result->end(),
std::back_inserter(result.dims),
[](StaticShape& s) {
return std::move(*s);
});
}

return result;
return shape_infer_result ? move_shapes_to_result(*shape_infer_result) : Result{{}, ShapeInferStatus::skip};
}

const ov::CoordinateDiff& get_pads_begin() override {
Expand All @@ -200,6 +187,15 @@ class ShapeInferBase : public IStaticShapeInfer {
protected:
std::vector<int64_t> m_input_ranks;
std::shared_ptr<ov::Node> m_node;

private:
static Result move_shapes_to_result(std::vector<StaticShape>& output_shapes) {
Result result{decltype(Result::dims){output_shapes.size()}, ShapeInferStatus::success};
std::transform(output_shapes.begin(), output_shapes.end(), result.dims.begin(), [](StaticShape& s) {
return std::move(*s);
});
return result;
}
};

/**
Expand Down

0 comments on commit 78a5ab0

Please sign in to comment.