From 4c40cb06c103a850e0f8d8f9c5131c4c173e6ded Mon Sep 17 00:00:00 2001 From: libing4752 Date: Thu, 14 Jun 2018 01:53:04 +0800 Subject: [PATCH] fix copro_sync.cc errors of ctx (#1274) --- src/pass/coproc_sync.cc | 2 +- .../python/unittest/test_pass_storage_sync.py | 37 +++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/src/pass/coproc_sync.cc b/src/pass/coproc_sync.cc index 28be8aba2057..b3e64a989702 100644 --- a/src/pass/coproc_sync.cc +++ b/src/pass/coproc_sync.cc @@ -385,7 +385,7 @@ class CoProcInstDepDetector : public IRVisitor { &(curr_state_.exit_push), &(curr_state_.enter_pop)); curr_state_.enter_ctx = first_state_.enter_ctx; - curr_state_.exit_ctx = last_state_.enter_ctx; + curr_state_.exit_ctx = last_state_.exit_ctx; } std::swap(first_state_, temp_first); std::swap(last_state_, temp_last); diff --git a/tests/python/unittest/test_pass_storage_sync.py b/tests/python/unittest/test_pass_storage_sync.py index ce9e2f9a4af9..2286dd53e981 100644 --- a/tests/python/unittest/test_pass_storage_sync.py +++ b/tests/python/unittest/test_pass_storage_sync.py @@ -78,7 +78,44 @@ def test_coproc_sync2(): stmt = ib.get() stmt = tvm.ir_pass.CoProcSync(stmt) +def test_coproc_sync3(): + def __check_list(tvm_array, py_list): + for ti, li in zip(tvm_array, py_list): + if ti.value != li: + return False + return True + + ib = tvm.ir_builder.create() + n = tvm.var("n") + cp = tvm.thread_axis((0, 1), "cop") + A = ib.allocate("float32", 128, name="A", scope="global.cache") + with ib.for_range(0, n, name="i") as i: + with ib.for_range(0, n, name="i") as j: + with ib.new_scope(): + ib.scope_attr(cp, "coproc_scope", 1) + A[i] = 1.0 + with ib.new_scope(): + ib.scope_attr(cp, "coproc_scope", 2) + A[i] = 1.0 + with ib.new_scope(): + ib.scope_attr(cp, "coproc_scope", 3) + A[0] = 0.0 + + stmt = ib.get() + stmt = tvm.ir_pass.CoProcSync(stmt) + slist = tvm.make.stmt_list(stmt.first.body.body) + push_st = slist[2] + slist = tvm.make.stmt_list(slist[-1]) + pop_st = slist[0].body.first + + assert(push_st.value.name == "cop.coproc_dep_push") + assert(__check_list(push_st.value.args, [2,3])) + assert(pop_st.value.name == "cop.coproc_dep_pop") + assert(__check_list(pop_st.value.args, [2,3])) + + if __name__ == "__main__": test_coproc_sync() test_storage_sync() test_coproc_sync2() + test_coproc_sync3()