From 8408ac0f81a0a55c3875064c1fbf3c1eec52e785 Mon Sep 17 00:00:00 2001 From: farshidsp Date: Thu, 23 Mar 2023 09:19:29 -0700 Subject: [PATCH 1/4] not estimating the flops when there is a default estimated flops as attr --- src/tir/analysis/estimate_flops.cc | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/tir/analysis/estimate_flops.cc b/src/tir/analysis/estimate_flops.cc index d158a001b2d8..a41adbed54f8 100644 --- a/src/tir/analysis/estimate_flops.cc +++ b/src/tir/analysis/estimate_flops.cc @@ -208,10 +208,16 @@ double EstimateTIRFlops(const Stmt& stmt) { double EstimateTIRFlops(const IRModule& mod) { FlopEstimator counter; TResult result; - VisitPrimFuncs(mod, [&result, &counter](const PrimFuncNode* f) { - result += counter.VisitStmt(f->body); // + double cached_result = 0; + VisitPrimFuncs(mod, [&result, &counter, &cached_result](const PrimFuncNode* f) { + if (auto cached = f->attrs.GetAttr("estimated_flops")) { + cached_result += cached.value()->value; + } + else { + result += counter.VisitStmt(f->body); // + } }); - return PostprocessResults(result); + return PostprocessResults(result) + cached_result; } TVM_REGISTER_GLOBAL("tir.analysis.EstimateTIRFlops").set_body_typed([](ObjectRef obj) -> double { From fd1cd05d25038b5c0867a21649e0ce83c1f229ad Mon Sep 17 00:00:00 2001 From: farshidsp Date: Thu, 23 Mar 2023 10:17:01 -0700 Subject: [PATCH 2/4] add unittests --- .../test_tir_analysis_estimate_tir_flops.py | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py b/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py index 06f6fe31278d..7b1457854397 100644 --- a/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py +++ b/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py @@ -77,5 +77,39 @@ def test_flops_with_if(): assert flops == 16 +@T.prim_func +def flops_with_forloop_as_expression(A: T.Buffer(1)): + for i in T.serial(0, 16): + for k in T.serial(0, i): + A[0] = A[0] + 1 + + +@T.prim_func +def flops_override(a: T.Buffer(16, "float32"), b: T.Buffer(16, "float32")): + T.func_attr({"estimated_flops": 32}) + for i in range(16): + if i % 2 == 0: + a[i] = b[i] + else: + if i % 3 == 0: + a[i] = b[i - 1] + b[i - 2] + + +def test_estimate_flops_forloop_as_experssion(): + flops = estimate_tir_flops( + IRModule({"main": flops_with_forloop_as_expression.with_attr("estimated_flops", 32)}) + ) + assert flops == 32 + + # test whether the user estimated flop would over ride + flops = estimate_tir_flops(IRModule({"main": flops_override})) + assert flops == 32 + + +def test_exception(): + with pytest.raises(tvm.TVMError): + flops = estimate_tir_flops(IRModule({"main": flops_with_forloop_as_expression})) + + if __name__ == "__main__": tvm.testing.main() From 3593e5f88e090e3be608fef806c17d59b2c529fa Mon Sep 17 00:00:00 2001 From: farshidsp Date: Thu, 23 Mar 2023 10:45:37 -0700 Subject: [PATCH 3/4] lint fix --- src/tir/analysis/estimate_flops.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/tir/analysis/estimate_flops.cc b/src/tir/analysis/estimate_flops.cc index a41adbed54f8..ce9d5eaaf838 100644 --- a/src/tir/analysis/estimate_flops.cc +++ b/src/tir/analysis/estimate_flops.cc @@ -212,8 +212,7 @@ double EstimateTIRFlops(const IRModule& mod) { VisitPrimFuncs(mod, [&result, &counter, &cached_result](const PrimFuncNode* f) { if (auto cached = f->attrs.GetAttr("estimated_flops")) { cached_result += cached.value()->value; - } - else { + } else { result += counter.VisitStmt(f->body); // } }); From 7db168552b97dd9d25953153b5c1081cf80ddc82 Mon Sep 17 00:00:00 2001 From: farshidsp Date: Thu, 23 Mar 2023 13:58:56 -0700 Subject: [PATCH 4/4] make unittest simpler --- .../unittest/test_tir_analysis_estimate_tir_flops.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py b/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py index 7b1457854397..489db287f377 100644 --- a/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py +++ b/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py @@ -85,14 +85,10 @@ def flops_with_forloop_as_expression(A: T.Buffer(1)): @T.prim_func -def flops_override(a: T.Buffer(16, "float32"), b: T.Buffer(16, "float32")): +def flops_override(A: T.Buffer(16, "float32")): T.func_attr({"estimated_flops": 32}) for i in range(16): - if i % 2 == 0: - a[i] = b[i] - else: - if i % 3 == 0: - a[i] = b[i - 1] + b[i - 2] + A[0] = A[0] + 1 def test_estimate_flops_forloop_as_experssion():