Skip to content

Commit

Permalink
timing added in find_matches_for with env var
Browse files Browse the repository at this point in the history
  • Loading branch information
aarushjain29 committed Oct 23, 2024
1 parent 54657c7 commit 99414ee
Showing 1 changed file with 34 additions and 22 deletions.
56 changes: 34 additions & 22 deletions src/include/migraphx/matcher.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include <array>
#include <unordered_map>
#include <unordered_set>
#include <migraphx/time.hpp>

#ifndef MIGRAPHX_USE_TYPE_ERASED_MATCHERS
#define MIGRAPHX_USE_TYPE_ERASED_MATCHERS 0
Expand Down Expand Up @@ -395,6 +396,7 @@ match::matcher_result find_match(module& modl, M&& m)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES_FOR)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_VALIDATE_MATCHES)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TIME_MATCHERS)

/// Find matches for an instruction in the module for per section of matchers
template <class Mod, class... Ms>
Expand All @@ -403,7 +405,9 @@ void find_matches_for(source_location location, Mod& mod, instruction_ref ins, M
const int trace = value_of(MIGRAPHX_TRACE_MATCHES{});
const bool validate = enabled(MIGRAPHX_VALIDATE_MATCHES{});
const auto trace_filter = string_value_of(MIGRAPHX_TRACE_MATCHES_FOR{});
const bool time_matchers = enabled(MIGRAPHX_TIME_MATCHERS{});
bool match = false;

each_args(
[&](auto&& m) {
const auto& matcher_name = get_type_name(m);
Expand All @@ -413,31 +417,39 @@ void find_matches_for(source_location location, Mod& mod, instruction_ref ins, M
contains(matcher_name, trace_filter));
if(match)
return;
if(trace > 1 and trace_for)
std::cout << "Match: " << matcher_name << std::endl;
auto r = match_instruction(get_module(mod), ins, m.matcher());
if(r.result == get_module(mod).end())
return;
if(trace > 0 or trace_for)
{
std::cout << "Matched by " << matcher_name << std::endl;
get_module(mod).debug_print(ins);
}
// If its already invalid dont validate it again
bool invalidated = validate and get_module(mod).validate() != get_module(mod).end();
m.apply(mod, r);
if(validate and not invalidated)
{
auto invalid = get_module(mod).validate();
if(invalid != get_module(mod).end())

auto elapsed_time = time<std::chrono::nanoseconds>([&] {

if(trace > 1 and trace_for)
std::cout << "Match: " << matcher_name << std::endl;
auto r = match_instruction(get_module(mod), ins, m.matcher());
if(r.result == get_module(mod).end())
return;
if(trace > 0 or trace_for)
{
std::cout << "Invalid program from match: " << matcher_name << std::endl;
std::cout << "Invalid instructions: " << std::endl;
get_module(mod).debug_print(invalid->inputs());
get_module(mod).debug_print(invalid);
std::cout << "Matched by " << matcher_name << std::endl;
get_module(mod).debug_print(ins);
}
// If its already invalid dont validate it again
bool invalidated = validate and get_module(mod).validate() != get_module(mod).end();
m.apply(mod, r);
if(validate and not invalidated)
{
auto invalid = get_module(mod).validate();
if(invalid != get_module(mod).end())
{
std::cout << "Invalid program from match: " << matcher_name << std::endl;
std::cout << "Invalid instructions: " << std::endl;
get_module(mod).debug_print(invalid->inputs());
get_module(mod).debug_print(invalid);
}
}
match = true;
});
if(time_matchers)
{
std::cout << "Matcher " << matcher_name << " took " << elapsed_time << "ns." << std::endl;
}
match = true;
},
ms...);
}
Expand Down

0 comments on commit 99414ee

Please sign in to comment.