From 275f8549488f1d35abcad9258feb714538769b45 Mon Sep 17 00:00:00 2001 From: Tim Gu Date: Wed, 16 Oct 2024 12:10:48 -0400 Subject: [PATCH] Fixed logic error of instruction::replace() (#3521) --- src/instruction.cpp | 13 +++++++++++-- test/instruction.cpp | 21 ++++++++++++++++++++- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/src/instruction.cpp b/src/instruction.cpp index 8ab75f0f178..47bea70379e 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -26,6 +26,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -62,10 +63,18 @@ void instruction::replace(const shape& r) if(r != result) { result = r; - for(auto&& ins : output) + std::deque q(output.begin(), output.end()); + while(not q.empty()) { + instruction_ref ins = q.front(); + q.pop_front(); assert(ins->name() == "@return" or ins->name().front() != '@'); - ins->recompute_shape(); + shape new_r = compute_shape(ins->op, ins->arguments, ins->module_args); + if(new_r != ins->result) + { + ins->result = new_r; + std::copy(ins->output.begin(), ins->output.end(), std::back_inserter(q)); + } } } } diff --git a/test/instruction.cpp b/test/instruction.cpp index 7b235cc09bf..134658e336b 100644 --- a/test/instruction.cpp +++ b/test/instruction.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -48,4 +48,23 @@ TEST_CASE(check_undefined) EXPECT(not mul->is_undefined()); } +TEST_CASE(check_replace_shape) +{ + migraphx::module m; + migraphx::shape s{migraphx::shape::float_type, {3, 2}}; + auto input = m.add_parameter("x", s); + auto reduce = m.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0}}}), input); + auto abs = m.add_instruction(migraphx::make_op("abs"), reduce); + auto sin = m.add_instruction(migraphx::make_op("sin"), reduce); + auto add = m.add_instruction(migraphx::make_op("add"), abs, sin); + + reduce->replace(migraphx::make_op("reduce_sum", {{"axes", {1}}})); + + migraphx::shape r{migraphx::shape::float_type, {3, 1}}; + EXPECT(reduce->get_shape() == r); + EXPECT(abs->get_shape() == r); + EXPECT(sin->get_shape() == r); + EXPECT(add->get_shape() == r); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); }