From 55c45386440afc5df87c20b07c6d41bf8a687bfd Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 4 Aug 2020 04:38:19 -0700 Subject: [PATCH] improve flop estimation --- src/auto_scheduler/compute_dag.cc | 28 ++++++++++++------- src/auto_scheduler/feature.cc | 17 +++-------- .../test_auto_scheduler_compute_dag.py | 10 +++++-- 3 files changed, 29 insertions(+), 26 deletions(-) diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index b11dd7347504a..ceaf94fe04bb1 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -536,17 +536,25 @@ class FlopEstimator : public ExprFunctor { double ret = 0; for (const auto& op : ops) { if (auto pop = op.as()) { - double num_element = AxisLengthProd(pop->axis); - if (num_element == -1) { - fail_ = true; - break; - } - cur_type_code_ = pop->output_dtype(0).code(); - double op_per_element = 0; - for (const auto& x : pop->body) { - op_per_element += VisitExpr(x); + if (pop->attrs.count("FLOP")) { + // Use user-provided FLOP + auto pint = pop->attrs["FLOP"].as(); + CHECK(pint != nullptr); + ret += pint->value; + } else { + // Estimate by parsing the compute body + double num_element = AxisLengthProd(pop->axis); + if (num_element == -1) { + fail_ = true; + break; + } + cur_type_code_ = pop->output_dtype(0).code(); + double op_per_element = 0; + for (const auto& x : pop->body) { + op_per_element += VisitExpr(x); + } + ret += num_element * op_per_element; } - ret += num_element * op_per_element; } else if (op->IsInstance()) { {} // do nothing } else { diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc index 9408f6e52865b..a2626ddba2254 100644 --- a/src/auto_scheduler/feature.cc +++ b/src/auto_scheduler/feature.cc @@ -75,12 +75,7 @@ enum class AnnotationPosType : int { }; // Buffer access type -enum class BufferAccessType : int { - kRead = 0, - kWrite = 1, - kReadWrite = 2, - kUnknownRW = 3 -}; +enum class BufferAccessType : int { kRead = 0, kWrite = 1, kReadWrite = 2, kUnknownRW = 3 }; // Accesses to a buffer struct BufferAccess { @@ -89,11 +84,7 @@ struct BufferAccess { }; // Data reuse type -enum class ReuseType : int { - kLoopMultipleRead = 0, - kSerialMultipleReadWrite = 1, - kNoReuse = 2 -}; +enum class ReuseType : int { kLoopMultipleRead = 0, kSerialMultipleReadWrite = 1, kNoReuse = 2 }; // Feature for an access of a buffer struct BufferAccessFeature { @@ -1514,8 +1505,8 @@ TVM_REGISTER_GLOBAL("auto_scheduler.GetPerStoreFeaturesFromFile") std::vector normalized_throughputs; std::vector task_ids; - GetPerStoreFeaturesFromFile(filename, max_lines, max_n_bufs, &features, &normalized_throughputs, - &task_ids); + GetPerStoreFeaturesFromFile(filename, max_lines, max_n_bufs, &features, + &normalized_throughputs, &task_ids); std::vector byte_data; *ret = SerializeFeatures(std::move(features), std::move(normalized_throughputs), diff --git a/tests/python/unittest/test_auto_scheduler_compute_dag.py b/tests/python/unittest/test_auto_scheduler_compute_dag.py index 2530d554e8eeb..6b76dc607da97 100644 --- a/tests/python/unittest/test_auto_scheduler_compute_dag.py +++ b/tests/python/unittest/test_auto_scheduler_compute_dag.py @@ -44,13 +44,17 @@ def test_estimate_flop(): D = topi.nn.relu(C) dag = auto_scheduler.ComputeDAG([A, B, D]) - assert abs(dag.flop_ct - 2 * N ** 3 - N * N) < 0.5 + assert abs(dag.flop_ct - (2 * N ** 3 + N * N)) < 0.5 # should not count the comparison operations in padding - D = topi.nn.pad(C, [1, 1]) - dag = auto_scheduler.ComputeDAG([A, B, D]) + E = topi.nn.pad(C, [1, 1]) + dag = auto_scheduler.ComputeDAG([A, B, E]) assert abs(dag.flop_ct - 2 * N ** 3) < 0.5 + F = te.compute((N, N), lambda i, j: E[i,j], name='F', attrs={"FLOP": 1234}) + dag = auto_scheduler.ComputeDAG([A, B, F]) + assert abs(dag.flop_ct - (2 * N ** 3 + 1234)) < 0.5 + if __name__ == "__main__": test_apply_steps()