Skip to content

Commit

Permalink
Fix related dim bug(when a_outer = 0) (PaddlePaddle#391)
Browse files Browse the repository at this point in the history
* fix related dim bug(a_outer = 0)

Co-authored-by: wangone <[email protected]>
  • Loading branch information
haozech and wenming2014 authored Jun 2, 2021
1 parent 53bdc07 commit eeaba40
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 27 deletions.
14 changes: 7 additions & 7 deletions README.md
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
```
___ ___ ___
/\__\ /\ \ /\ \ ______________
/\__\ /\ \ /\ \
/:/ / ___ \:\ \ \:\ \
/:/ / /\__\ \:\ \ \:\ \ ______________
/:/ / /\__\ \:\ \ \:\ \
/:/ / ___ /:/__/ _____\:\ \ _____\:\ \
/:/__/ /\__\/::\ \ /::::::::\__\/::::::::\__\ ______________
/:/__/ /\__\/::\ \ /::::::::\__\/::::::::\__\
\:\ \ /:/ /\/\:\ \__\:\~~\~~\/__/\:\~~\~~\/__/
\:\ /:/ / ~~\:\/\__\\:\ \ \:\ \ _________________
\:\/:/ / \::/ / \:\ \ \:\ \ ______________
\::/ / /:/ / \:\__\ \:\__\ ______________
\/__/ \/__/ \/__/ \/__/ ______________
\:\ /:/ / \:\/\__\\:\ \ \:\ \
\:\/:/ / \::/ / \:\ \ \:\ \
\::/ / /:/ / \:\__\ \:\__\
\/__/ \/__/ \/__/ \/__/
```

Expand Down
37 changes: 30 additions & 7 deletions cinn/poly/isl_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <isl/cpp.h>

#include <algorithm>
#include <set>

#include "cinn/utils/string.h"

Expand Down Expand Up @@ -358,34 +359,56 @@ isl::map RemoveAxiesByOutputNames(const isl::map &x, const std::vector<std::stri

std::vector<std::string> GetRelatedOutputAxies(const isl::map &x, const std::vector<std::string> &dim_in_names) {
std::string map_str = isl_map_to_str(x.get());
isl::ctx this_ctx = x.ctx();
VLOG(1) << "GetRelatedOutputAxies map_str is : " << map_str;
isl::ctx this_ctx = x.ctx();
isl::map temp_transform(this_ctx, map_str);
auto dim_out_names = isl_get_dim_names(temp_transform, isl_dim_out);
std::set<std::string> dim_in_set;
for (auto &i : dim_in_names) {
temp_transform = isl::manage(isl_remove_axis_by_name(temp_transform.release(), isl_dim_in, i.c_str()));
VLOG(1) << "GetRelatedOutputAxies dim_in_names is : " << i;
dim_in_set.insert(i);
}
std::string deleted_map = isl_map_to_str(temp_transform.get());
std::vector<std::string> res;
std::set<std::string> res_set;
for (auto &i : dim_out_names) {
if (utils::Count(&map_str, i) != utils::Count(&deleted_map, i)) {
res.push_back(i);
auto related_in_dim = GetRelatedInputAxies(temp_transform, {i});
for (auto &j : related_in_dim) {
if (dim_in_set.count(j) > 0) {
res_set.insert(i);
}
}
}
std::vector<std::string> res;
for (auto &i : res_set) {
VLOG(1) << "GetRelatedOutputAxies res is : " << i;
res.push_back(i);
}
return res;
}

std::vector<std::string> GetRelatedInputAxies(const isl::map &x, const std::vector<std::string> &dim_out_names) {
std::string map_str = isl_map_to_str(x.get());
isl::ctx this_ctx = x.ctx();
VLOG(1) << "GetRelatedInputAxies map_str is : " << map_str;
isl::ctx this_ctx = x.ctx();
isl::map temp_transform(this_ctx, map_str);
auto dim_in_names = isl_get_dim_names(temp_transform, isl_dim_in);
for (auto &i : dim_out_names) {
VLOG(1) << "GetRelatedInputAxies dim_out_names is : " << i;
temp_transform = isl::manage(isl_remove_axis_by_name(temp_transform.release(), isl_dim_out, i.c_str()));
}
std::string deleted_map = isl_map_to_str(temp_transform.get());
std::vector<std::string> res;
std::set<std::string> out_set;
for (auto &i : dim_out_names) {
if (utils::Endswith(i, "_inner") || utils::Endswith(i, "_outer")) {
out_set.insert(i);
}
}
for (auto &i : dim_in_names) {
if (utils::Count(&map_str, i) != utils::Count(&deleted_map, i)) {
VLOG(1) << "GetRelatedInputAxies res is : " << i;
res.push_back(i);
} else if (out_set.count(i + "_outer") > 0 || out_set.count(i + "_inner") > 0) {
VLOG(1) << "GetRelatedInputAxies res is : " << i;
res.push_back(i);
}
}
Expand Down
23 changes: 10 additions & 13 deletions cinn/poly/stage.cc
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -1052,20 +1052,17 @@ void Stage::CopyTransform(Stage *other, int level) {
isl::map temp_target_trans(this_ctx, str_target_trans);
if (level + 1 < isl_map_dim(temp_target_trans.get(), isl_dim_out)) {
std::string pivot_dim_out = isl_map_get_dim_name(temp_target_trans.get(), isl_dim_out, level + 1);
temp_target_trans = isl::manage(isl_map_remove_dims(temp_target_trans.release(), isl_dim_out, 0, level + 1));
std::string map_after_deletion = isl_map_to_str(temp_target_trans.get());

std::string pivot_dim_in;
for (int i = 0; i < target_map_dims.size(); i++) {
if (utils::Count(&map_after_deletion, target_map_dims[i]) > 1) {
pivot_dim_in = target_map_dims[i];
break;
}
std::vector<std::string> dim_out_level;
for (int i = 0; i <= level; i++) {
dim_out_level.push_back(isl_map_get_dim_name(temp_target_trans.get(), isl_dim_out, i));
}
if (utils::Count(&str_target_trans, pivot_dim_in) != utils::Count(&map_after_deletion, pivot_dim_in) ||
utils::Count(&str_target_trans, pivot_dim_out) != utils::Count(&map_after_deletion, pivot_dim_out)) {
this->CopyTransform(other, level + 1);
return;
auto related_dim_in = GetRelatedInputAxies(temp_target_trans, dim_out_level);
auto related_dim_out = GetRelatedOutputAxies(temp_target_trans, related_dim_in);
for (auto &i : related_dim_out) {
if (i == pivot_dim_out) {
this->CopyTransform(other, level + 1);
return;
}
}
} else if (level >= isl_map_dim(temp_target_trans.get(), isl_dim_out)) {
LOG(ERROR) << "ComputeAt level: " << level
Expand Down

0 comments on commit eeaba40

Please sign in to comment.