Skip to content

Commit

Permalink
fix copro_sync.cc errors of ctx (#1274)
Browse files Browse the repository at this point in the history
  • Loading branch information
libing4752 authored and tqchen committed Jun 13, 2018
1 parent d8faa50 commit 4c40cb0
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/pass/coproc_sync.cc
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ class CoProcInstDepDetector : public IRVisitor {
&(curr_state_.exit_push),
&(curr_state_.enter_pop));
curr_state_.enter_ctx = first_state_.enter_ctx;
curr_state_.exit_ctx = last_state_.enter_ctx;
curr_state_.exit_ctx = last_state_.exit_ctx;
}
std::swap(first_state_, temp_first);
std::swap(last_state_, temp_last);
Expand Down
37 changes: 37 additions & 0 deletions tests/python/unittest/test_pass_storage_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,44 @@ def test_coproc_sync2():
stmt = ib.get()
stmt = tvm.ir_pass.CoProcSync(stmt)

def test_coproc_sync3():
def __check_list(tvm_array, py_list):
for ti, li in zip(tvm_array, py_list):
if ti.value != li:
return False
return True

ib = tvm.ir_builder.create()
n = tvm.var("n")
cp = tvm.thread_axis((0, 1), "cop")
A = ib.allocate("float32", 128, name="A", scope="global.cache")
with ib.for_range(0, n, name="i") as i:
with ib.for_range(0, n, name="i") as j:
with ib.new_scope():
ib.scope_attr(cp, "coproc_scope", 1)
A[i] = 1.0
with ib.new_scope():
ib.scope_attr(cp, "coproc_scope", 2)
A[i] = 1.0
with ib.new_scope():
ib.scope_attr(cp, "coproc_scope", 3)
A[0] = 0.0

stmt = ib.get()
stmt = tvm.ir_pass.CoProcSync(stmt)
slist = tvm.make.stmt_list(stmt.first.body.body)
push_st = slist[2]
slist = tvm.make.stmt_list(slist[-1])
pop_st = slist[0].body.first

assert(push_st.value.name == "cop.coproc_dep_push")
assert(__check_list(push_st.value.args, [2,3]))
assert(pop_st.value.name == "cop.coproc_dep_pop")
assert(__check_list(pop_st.value.args, [2,3]))


if __name__ == "__main__":
test_coproc_sync()
test_storage_sync()
test_coproc_sync2()
test_coproc_sync3()

0 comments on commit 4c40cb0

Please sign in to comment.