Skip to content

Commit

Permalink
Merge branch 'develop' into dump-mlir
Browse files Browse the repository at this point in the history
  • Loading branch information
pfultz2 authored Oct 16, 2024
2 parents 5074563 + 275f854 commit fff0dc9
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 5 deletions.
13 changes: 11 additions & 2 deletions src/instruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <migraphx/erase.hpp>
#include <migraphx/module.hpp>
#include <migraphx/ranges.hpp>
#include <deque>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand Down Expand Up @@ -62,10 +63,18 @@ void instruction::replace(const shape& r)
if(r != result)
{
result = r;
for(auto&& ins : output)
std::deque<instruction_ref> q(output.begin(), output.end());
while(not q.empty())
{
instruction_ref ins = q.front();
q.pop_front();
assert(ins->name() == "@return" or ins->name().front() != '@');
ins->recompute_shape();
shape new_r = compute_shape(ins->op, ins->arguments, ins->module_args);
if(new_r != ins->result)
{
ins->result = new_r;
std::copy(ins->output.begin(), ins->output.end(), std::back_inserter(q));
}
}
}
}
Expand Down
8 changes: 7 additions & 1 deletion src/targets/gpu/compile_hip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,15 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
bool hip_has_flags(const std::vector<std::string>& flags)
{
hiprtc_program prog{" "};

std::string src = " ";
src_file input{"main.cpp", src};
std::vector<src_file> srcs = {input};

try
{
prog.compile(flags, true);
std::string arch = "gfx900";
compile_hip_src(srcs, flags, arch);
return true;
}
catch(...)
Expand Down
21 changes: 20 additions & 1 deletion test/instruction.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -48,4 +48,23 @@ TEST_CASE(check_undefined)
EXPECT(not mul->is_undefined());
}

TEST_CASE(check_replace_shape)
{
migraphx::module m;
migraphx::shape s{migraphx::shape::float_type, {3, 2}};
auto input = m.add_parameter("x", s);
auto reduce = m.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0}}}), input);
auto abs = m.add_instruction(migraphx::make_op("abs"), reduce);
auto sin = m.add_instruction(migraphx::make_op("sin"), reduce);
auto add = m.add_instruction(migraphx::make_op("add"), abs, sin);

reduce->replace(migraphx::make_op("reduce_sum", {{"axes", {1}}}));

migraphx::shape r{migraphx::shape::float_type, {3, 1}};
EXPECT(reduce->get_shape() == r);
EXPECT(abs->get_shape() == r);
EXPECT(sin->get_shape() == r);
EXPECT(add->get_shape() == r);
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }
2 changes: 1 addition & 1 deletion test/onnx/.onnxrt-commit
Original file line number Diff line number Diff line change
@@ -1 +1 @@
f25f3868a75d4422cde0090abc2781a5277e8eee
3c80aa9feed4ee1249fdcadfd50ad66b8e039ac1

0 comments on commit fff0dc9

Please sign in to comment.