From c1d2d5afdeec5d1d352bca8d33a9ae0d8f922c49 Mon Sep 17 00:00:00 2001 From: Paul Fultz II Date: Mon, 21 Oct 2024 11:29:58 -0500 Subject: [PATCH] Handle more reshapes with the shape transform descriptor (#3527) --- .../migraphx/shape_transform_descriptor.hpp | 1 + src/shape_transform_descriptor.cpp | 14 ++++++++++++++ test/shape_transform_descriptor.cpp | 9 +++++++++ 3 files changed, 24 insertions(+) diff --git a/src/include/migraphx/shape_transform_descriptor.hpp b/src/include/migraphx/shape_transform_descriptor.hpp index a8851de753f..4160759d47d 100644 --- a/src/include/migraphx/shape_transform_descriptor.hpp +++ b/src/include/migraphx/shape_transform_descriptor.hpp @@ -76,6 +76,7 @@ struct MIGRAPHX_EXPORT shape_transform_descriptor bool apply(const std::vector& ops); bool apply_reshape(const std::vector& rdims); + bool apply_reshape_impl(const std::vector& rdims); bool apply_transpose(const std::vector& permutation); bool apply_broadcast(const std::vector& out_lens, optional axis = nullopt); diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index e336674d660..51802ef0fae 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -143,8 +143,22 @@ bool shape_transform_descriptor::apply(const std::vector& ops) return true; } bool shape_transform_descriptor::apply_reshape(const std::vector& rdims) +{ + std::vector idims; + transform(get_all_subdimensions(dimensions), + std::back_inserter(idims), + std::mem_fn(&dimension::sub::len)); + auto cdims = common_dims::compute(idims, rdims).dims; + if(not cdims.empty() and not apply_reshape_impl(cdims)) + return false; + return apply_reshape_impl(rdims); +} +bool shape_transform_descriptor::apply_reshape_impl(const std::vector& rdims) { assert(migraphx::elements(rdims) == this->elements()); + if(migraphx::equal( + dimensions, rdims, [](const dimension& d, std::size_t rdim) { return d.len() == rdim; })) + return true; std::vector new_dims; auto subs = get_all_subdimensions(dimensions); std::size_t i = 0; diff --git a/test/shape_transform_descriptor.cpp b/test/shape_transform_descriptor.cpp index ccb7cbb98be..8a5b7cf34f2 100644 --- a/test/shape_transform_descriptor.cpp +++ b/test/shape_transform_descriptor.cpp @@ -134,6 +134,15 @@ TEST_CASE(record_reshape_split) EXPECT(get_all_axes(desc) == all_axes{d_axes{{0}}, d_axes{{1, 0}}, d_axes{{1, 1}}}); } +TEST_CASE(record_reshape_merge_split) +{ + auto desc = make_descriptor({3, 10, 16}, make_op("reshape", {{"dims", {3, 40, 2, 2}}})); + EXPECT(get_final_lens(desc) == final_lens{3, 40, 2, 2}); + EXPECT(get_all_lens(desc) == all_lens{{3}, {10, 4}, {2}, {2}}); + EXPECT(get_all_axes(desc) == + all_axes{d_axes{{0}}, d_axes{{1}, {2, 0}}, d_axes{{2, 1}}, d_axes{{2, 2}}}); +} + TEST_CASE(record_squeeze_trailing_1s) { auto desc = make_descriptor({3, 4, 4, 1, 1}, make_op("reshape", {{"dims", {3, 4, 4}}}));