diff --git a/python/tvm/relay/backend/profiler_vm.py b/python/tvm/relay/backend/profiler_vm.py index 5ee2d66c878a..fa0326edb95b 100644 --- a/python/tvm/relay/backend/profiler_vm.py +++ b/python/tvm/relay/backend/profiler_vm.py @@ -38,8 +38,21 @@ def __init__(self, mod): self._set_input = self.mod["set_input"] self._reset = self.mod["reset"] - def get_stat(self): - return self._get_stat() + def get_stat(self, sort_by_time=True): + """Get the statistics of executed ops. + + Parameters + ---------- + sort_by_time: Optional[Boolean] + Set to indicate the returned results are sorted by execution time in + the descending order. It is printed in the random order if this + field is not set. + + Returns + ------- + The execution statistics in string. + """ + return self._get_stat(sort_by_time) def reset(self): self._reset() diff --git a/src/runtime/vm/profiler/vm.cc b/src/runtime/vm/profiler/vm.cc index b004f67e35d2..3b7b7aa8e73e 100644 --- a/src/runtime/vm/profiler/vm.cc +++ b/src/runtime/vm/profiler/vm.cc @@ -31,6 +31,7 @@ #include #include #include +#include #include #include "vm.h" @@ -43,16 +44,32 @@ PackedFunc VirtualMachineDebug::GetFunction( const std::string& name, const ObjectPtr& sptr_to_self) { if (name == "get_stat") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK_EQ(args.size(), 1U); + std::vector> op_acc_time; + for (auto kv : op_durations_) { + auto val = std::make_pair( + kv.first, std::accumulate(kv.second.begin(), kv.second.end(), 0.0)); + op_acc_time.push_back(val); + } + bool sort_by_time = args[0]; + if (sort_by_time) { + auto comp = [](const std::pair& lhs, + const std::pair& rhs) { + return lhs.second > rhs.second; + }; + std::sort(op_acc_time.begin(), op_acc_time.end(), comp); + } double total_duration = 0.0; + int64_t total_packed_funcs = 0; std::ostringstream os; os << std::setw(30) << std::left << "#OpName" << "\t" << std::setw(10) << std::left << "#InvokeCount" << "\t" << "#Duration(us): Sum/Mean/Min/Max" << std::endl; - for (auto kv : op_durations_) { + for (auto kv : op_acc_time) { auto vals = op_durations_[kv.first]; - auto sum = std::accumulate(vals.begin(), vals.end(), 0.0);; + auto sum = kv.second; auto mean = sum / static_cast(vals.size()); auto min_value = *std::min_element(vals.begin(), vals.end()); auto max_value = *std::max_element(vals.begin(), vals.end()); @@ -62,8 +79,10 @@ PackedFunc VirtualMachineDebug::GetFunction( << sum << "/" << mean << "/" << min_value << "/" << max_value << std::endl; total_duration += sum; + total_packed_funcs += op_invokes_[kv.first]; } - os << "Total Duration " << total_duration << " us" << std::endl; + os << "\nTotal Duration: " << total_duration << " us.\t" + << "Total Packed Functions: " << total_packed_funcs << std::endl; *rv = os.str(); }); } else if (name == "reset") { diff --git a/tests/python/unittest/test_runtime_vm_profiler.py b/tests/python/unittest/test_runtime_vm_profiler.py index 6cfe6e83cc4a..b7bbe2f64c96 100644 --- a/tests/python/unittest/test_runtime_vm_profiler.py +++ b/tests/python/unittest/test_runtime_vm_profiler.py @@ -35,6 +35,7 @@ def test_basic(): data = np.random.rand(1, 3, 224, 224).astype('float32') res = vm.invoke("main", [data]) print("\n{}".format(vm.get_stat())) + print("\n{}".format(vm.get_stat(False))) if __name__ == "__main__": test_basic()