Patch summary: let a PrimFunc's "estimated_flops" attribute override FLOP estimation.

src/tir/analysis/estimate_flops.cc
  EstimateTIRFlops(const IRModule&) now looks for an "estimated_flops"
  attribute on every PrimFunc it visits. When the attribute is present, its
  value is added to a separate running total and the function body is not
  visited; otherwise the body is handed to FlopEstimator as before. The
  returned value is PostprocessResults(result) plus the total collected from
  attributes. The attribute is read with GetAttr<Integer>, so it is expected
  to hold an integer.

tests/python/unittest/test_tir_analysis_estimate_tir_flops.py
  Adds cases where the attribute is supplied through with_attr() and through
  T.func_attr(), both expecting a result of 32, and a case expecting
  tvm.TVMError when flops_with_forloop_as_expression carries no attribute.

diff --git a/src/tir/analysis/estimate_flops.cc b/src/tir/analysis/estimate_flops.cc
index d158a001b2d8..ce9d5eaaf838 100644
--- a/src/tir/analysis/estimate_flops.cc
+++ b/src/tir/analysis/estimate_flops.cc
@@ -208,10 +208,15 @@ double EstimateTIRFlops(const Stmt& stmt) {
 double EstimateTIRFlops(const IRModule& mod) {
   FlopEstimator counter;
   TResult result;
-  VisitPrimFuncs(mod, [&result, &counter](const PrimFuncNode* f) {
-    result += counter.VisitStmt(f->body);  //
+  double cached_result = 0;
+  VisitPrimFuncs(mod, [&result, &counter, &cached_result](const PrimFuncNode* f) {
+    if (auto cached = f->attrs.GetAttr<Integer>("estimated_flops")) {
+      cached_result += cached.value()->value;
+    } else {
+      result += counter.VisitStmt(f->body);  //
+    }
   });
-  return PostprocessResults(result);
+  return PostprocessResults(result) + cached_result;
 }
 
 TVM_REGISTER_GLOBAL("tir.analysis.EstimateTIRFlops").set_body_typed([](ObjectRef obj) -> double {
diff --git a/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py b/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py
index 06f6fe31278d..489db287f377 100644
--- a/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py
+++ b/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py
@@ -77,5 +77,35 @@ def test_flops_with_if():
     assert flops == 16
 
 
+@T.prim_func
+def flops_with_forloop_as_expression(A: T.Buffer(1)):
+    for i in T.serial(0, 16):
+        for k in T.serial(0, i):
+            A[0] = A[0] + 1
+
+
+@T.prim_func
+def flops_override(A: T.Buffer(16, "float32")):
+    T.func_attr({"estimated_flops": 32})
+    for i in range(16):
+        A[0] = A[0] + 1
+
+
+def test_estimate_flops_forloop_as_experssion():
+    flops = estimate_tir_flops(
+        IRModule({"main": flops_with_forloop_as_expression.with_attr("estimated_flops", 32)})
+    )
+    assert flops == 32
+
+    # test whether the user estimated flop would over ride
+    flops = estimate_tir_flops(IRModule({"main": flops_override}))
+    assert flops == 32
+
+
+def test_exception():
+    with pytest.raises(tvm.TVMError):
+        flops = estimate_tir_flops(IRModule({"main": flops_with_forloop_as_expression}))
+
+
 if __name__ == "__main__":
     tvm.testing.main()