diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc index a6d6b27d8186..5810f1b5337f 100644 --- a/src/relax/ir/binding_rewrite.cc +++ b/src/relax/ir/binding_rewrite.cc @@ -204,6 +204,7 @@ class RemoveUnusedVars : public ExprMutator { do { prev_size = unused.size(); + std::vector users_keys; for (const auto& kv : users) { // var -> [users...] // var is unused iff @@ -212,17 +213,21 @@ class RemoveUnusedVars : public ExprMutator { if (kv.second.empty() && // kv.first is not used by fn outputs. fn_outputs.end() == std::find(fn_outputs.begin(), fn_outputs.end(), kv.first)) { unused.push_back(kv.first); + } else { + users_keys.push_back(kv.first); } } for (size_t i = prev_size; i < unused.size(); ++i) { users.erase(unused[i]); // remove def site. - for (auto kv : users) { // remove use site. - auto it = std::find(kv.second.begin(), kv.second.end(), unused[i]); - if (it != kv.second.end()) { - kv.second.erase(it); - users.Set(kv.first, std::move(kv.second)); + for (const auto& key: users_keys) { // remove use site. + ICHECK(users.count(key)) << "the key " << key << " is expected to be in the mapping users."; + Array cur_users = users[key]; + auto it = std::find(cur_users.begin(), cur_users.end(), unused[i]); + if (it != cur_users.end()) { + cur_users.erase(it); + users.Set(key, std::move(cur_users)); } } }