Skip to content

Commit

Permalink
Sort VM stats by time (apache#4601)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Mar 2, 2020
1 parent 00eb18c commit 6f9b5b5
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 5 deletions.
17 changes: 15 additions & 2 deletions python/tvm/relay/backend/profiler_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,21 @@ def __init__(self, mod):
self._set_input = self.mod["set_input"]
self._reset = self.mod["reset"]

def get_stat(self):
return self._get_stat()
def get_stat(self, sort_by_time=True):
"""Get the statistics of executed ops.
Parameters
----------
sort_by_time: Optional[Boolean]
Set to indicate the returned results are sorted by execution time in
the descending order. It is printed in the random order if this
field is not set.
Returns
-------
The execution statistics in string.
"""
return self._get_stat(sort_by_time)

def reset(self):
self._reset()
25 changes: 22 additions & 3 deletions src/runtime/vm/profiler/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include "vm.h"
Expand All @@ -43,16 +44,32 @@ PackedFunc VirtualMachineDebug::GetFunction(
const std::string& name, const ObjectPtr<Object>& sptr_to_self) {
if (name == "get_stat") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.size(), 1U);
std::vector<std::pair<Index, double>> op_acc_time;
for (auto kv : op_durations_) {
auto val = std::make_pair(
kv.first, std::accumulate(kv.second.begin(), kv.second.end(), 0.0));
op_acc_time.push_back(val);
}
bool sort_by_time = args[0];
if (sort_by_time) {
auto comp = [](const std::pair<Index, double>& lhs,
const std::pair<Index, double>& rhs) {
return lhs.second > rhs.second;
};
std::sort(op_acc_time.begin(), op_acc_time.end(), comp);
}
double total_duration = 0.0;
int64_t total_packed_funcs = 0;
std::ostringstream os;
os << std::setw(30) << std::left << "#OpName"
<< "\t" << std::setw(10) << std::left << "#InvokeCount"
<< "\t"
<< "#Duration(us): Sum/Mean/Min/Max" << std::endl;

for (auto kv : op_durations_) {
for (auto kv : op_acc_time) {
auto vals = op_durations_[kv.first];
auto sum = std::accumulate(vals.begin(), vals.end(), 0.0);;
auto sum = kv.second;
auto mean = sum / static_cast<double>(vals.size());
auto min_value = *std::min_element(vals.begin(), vals.end());
auto max_value = *std::max_element(vals.begin(), vals.end());
Expand All @@ -62,8 +79,10 @@ PackedFunc VirtualMachineDebug::GetFunction(
<< sum << "/" << mean << "/" << min_value << "/" << max_value << std::endl;

total_duration += sum;
total_packed_funcs += op_invokes_[kv.first];
}
os << "Total Duration " << total_duration << " us" << std::endl;
os << "\nTotal Duration: " << total_duration << " us.\t"
<< "Total Packed Functions: " << total_packed_funcs << std::endl;
*rv = os.str();
});
} else if (name == "reset") {
Expand Down
1 change: 1 addition & 0 deletions tests/python/unittest/test_runtime_vm_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def test_basic():
data = np.random.rand(1, 3, 224, 224).astype('float32')
res = vm.invoke("main", [data])
print("\n{}".format(vm.get_stat()))
print("\n{}".format(vm.get_stat(False)))

if __name__ == "__main__":
test_basic()

0 comments on commit 6f9b5b5

Please sign in to comment.