Cleanup trivial reduction workarounds #2006

Merged
merged 6 commits on Oct 1, 2022
26 changes: 10 additions & 16 deletions torch/csrc/jit/codegen/cuda/inlining.cpp
@@ -153,29 +153,25 @@ size_t MaxPosCalculator::getMaxPosAll(
return max_pos;
}

void inlineMost(const std::unordered_set<IterDomain*>& uninlinable_ids) {
inlineMost(ir_utils::allTvs(FusionGuard::getCurFusion()), uninlinable_ids);
void inlineMost() {
Owner: For now this seems fine to remove, though I'm open-minded that we may want ID-based avoidance of inlining certain dimensions, but we probably want a better interface for that.

inlineMost(ir_utils::allTvs(FusionGuard::getCurFusion()));
}

void inlineMost(
const std::vector<TensorView*>& tvs,
const std::unordered_set<IterDomain*>& uninlinable_ids) {
void inlineMost(const std::vector<TensorView*>& tvs) {
if (tvs.empty()) {
return;
}
MaxPosCalculator calc(uninlinable_ids);
MaxPosCalculator calc;
for (auto tv : tvs) {
tv->inlineAt(-1, true, &calc);
}
}

void inlineMost(
const std::unordered_set<TensorView*>& tvs,
const std::unordered_set<IterDomain*>& uninlinable_ids) {
void inlineMost(const std::unordered_set<TensorView*>& tvs) {
if (tvs.empty()) {
return;
}
MaxPosCalculator calc(uninlinable_ids);
MaxPosCalculator calc;
for (auto tv : tvs) {
tv->inlineAt(-1, true, &calc);
}
@@ -276,10 +272,9 @@ std::unordered_map<TensorView*, size_t> getPositionsMappedTo(
void inlineAllAt(
TensorView* reference_tv,
int64_t reference_pos,
bool best_effort,
const std::unordered_set<IterDomain*>& uninlinable_ids) {
bool best_effort) {
auto mapped_positions = getPositionsMappedTo(reference_tv, reference_pos);
MaxPosCalculator calc(uninlinable_ids);
MaxPosCalculator calc;
for (auto pair : mapped_positions) {
pair.first->inlineAt(pair.second, best_effort, &calc);
}
@@ -289,10 +284,9 @@ void inlineSelectedAt(
const std::unordered_set<TensorView*>& selected,
TensorView* reference_tv,
int64_t reference_pos,
bool best_effort,
const std::unordered_set<IterDomain*>& uninlinable_ids) {
bool best_effort) {
auto mapped_positions = getPositionsMappedTo(reference_tv, reference_pos);
MaxPosCalculator calc(uninlinable_ids);
MaxPosCalculator calc;
for (auto pair : mapped_positions) {
if (selected.count(pair.first) > 0) {
pair.first->inlineAt(pair.second, best_effort, &calc);
17 changes: 5 additions & 12 deletions torch/csrc/jit/codegen/cuda/inlining.h
@@ -64,35 +64,28 @@ class MaxPosCalculator {

// Inline to the right most allowed position for all tensors in the current
// fusion.
TORCH_CUDA_CU_API void inlineMost(
const std::unordered_set<IterDomain*>& uninlinable_ids = {});
TORCH_CUDA_CU_API void inlineMost();
// Inline to the right most allowed position for the selected tensors in the
// current fusion.
TORCH_CUDA_CU_API void inlineMost(
const std::vector<TensorView*>& tvs,
const std::unordered_set<IterDomain*>& uninlinable_ids = {});
TORCH_CUDA_CU_API void inlineMost(const std::vector<TensorView*>& tvs);
// Inline to the right most allowed position for the selected tensors in the
// current fusion.
TORCH_CUDA_CU_API void inlineMost(
const std::unordered_set<TensorView*>& tvs,
const std::unordered_set<IterDomain*>& uninlinable_ids = {});
TORCH_CUDA_CU_API void inlineMost(const std::unordered_set<TensorView*>& tvs);

// Inline to the position corresponding to the reference position in the
// reference tensor for all tensors in the current fusion.
TORCH_CUDA_CU_API void inlineAllAt(
TensorView* reference_tv,
int64_t reference_pos,
bool best_effort = false,
const std::unordered_set<IterDomain*>& uninlinable_ids = {});
bool best_effort = false);

// Inline to the position corresponding to the reference position in the
// reference tensor for selected tensors in the current fusion.
TORCH_CUDA_CU_API void inlineSelectedAt(
const std::unordered_set<TensorView*>& selected,
TensorView* reference_tv,
int64_t reference_pos,
bool best_effort = false,
const std::unordered_set<IterDomain*>& uninlinable_ids = {});
bool best_effort = false);

} // namespace cuda
} // namespace fuser
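
A minimal usage sketch (not part of this diff) of the simplified interface declared above. It assumes torch/csrc/jit/codegen/cuda/inlining.h is included, namespace qualifiers are omitted, and fusion, reference_tv, and selected already exist; the helper name inlineSchedule is hypothetical.

// Sketch: a scheduler driving the trimmed-down inlining API.
void inlineSchedule(
    Fusion* fusion,
    TensorView* reference_tv,
    const std::unordered_set<TensorView*>& selected) {
  FusionGuard fg(fusion);

  // Inline every tensor in the fusion as deeply as allowed. Callers no
  // longer pass a set of uninlinable trivial-reduction IDs.
  inlineMost();

  // Alternatively, inline only the selected tensors at the position mapped
  // from reference_tv's axis 1, falling back best-effort where a deeper
  // position is not allowed.
  inlineSelectedAt(selected, reference_tv, /*reference_pos=*/1, /*best_effort=*/true);
}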
77 changes: 38 additions & 39 deletions torch/csrc/jit/codegen/cuda/ir_nodes.cpp
@@ -1597,6 +1597,43 @@ std::vector<IterDomain*> IterDomain::clone(
return cloned_domains;
}

IterType inferIterType(IterDomain* i1, IterDomain* i2) {
Collaborator: @csarofeen I don't have any concerns, but please take a look here just in case.

Owner: Only nit I have is changing the name to make clear it only handles resolution through serial, broadcast, and trivially reduced domains; in other words, it shouldn't be used for e.g. gather/stride IDs. Nice cleanup though.

Collaborator Author (@zasdfgbnm, Sep 30, 2022): I think for gather and stride, it applies the first rule, X + X = X. Do we want to do it differently?

// The iter type inference is a pattern match over the rules below:
//
// X + X = X
// trivial reduction + X = X
// X + trivial reduction = X
// broadcasting + X = X
// X + broadcasting = X
// fail
//
// The rules are applied one by one in order. For each rule, we test whether
// the given (outer, inner) pair matches the pattern. If it does, we stop and
// take that result. If we reach the end without finding any matching
// pattern, then it is a mistake and should be reported.
//
// Note that based on the above rule:
// broadcasting + (non-trivial) reduction = reduction
// broadcasting + trivial reduction = broadcasting
if (i1->getIterType() == i2->getIterType()) {
return i1->getIterType();
}
if (i1->isTrivialReduction()) {
return i2->getIterType();
}
if (i2->isTrivialReduction()) {
return i1->getIterType();
}
if (i1->isBroadcast()) {
return i2->getIterType();
}
if (i2->isBroadcast()) {
return i1->getIterType();
}
TORCH_CHECK(
false, "Merging IterDomains requires that their iteration types match.");
}

// Merging does not propagate the start and stop values of the input
// domains to the merged output domain. The actual range of the
// domains is enforced by predicates. Note that since only root
@@ -1606,48 +1643,10 @@ IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) {
TORCH_CHECK(
!outer->extent()->isZeroInt() && !inner->extent()->isZeroInt(),
"Merging IterDomains with ending values that are 0 is not supported at this time.");
TORCH_CHECK(
outer->isReduction() == inner->isReduction() ||
(!outer->isReduction() && inner->isTrivialReduction()) ||
(outer->isTrivialReduction() && !inner->isReduction()),
"Merging IterDomains requires that their iteration types match.");
TORCH_CHECK(
(outer->isGather() && inner->isGather()) ||
(!outer->isGather() && !inner->isGather()),
"Merging gather and non-gather domains is not supported.");

TORCH_CHECK(
!outer->isStride() && !inner->isStride(),
"No support for merging stride domains");

Val* merged_id_size = mul(outer->extent(), inner->extent());

IterType itype = outer->getIterType();

if (outer->isBroadcast() && inner->isBroadcast()) {
itype = IterType::Broadcast;
}

if ((outer->isBroadcast() || inner->isBroadcast()) &&
(outer->getIterType() == IterType::Iteration ||
inner->getIterType() == IterType::Iteration)) {
itype = IterType::Iteration;
}

// Merging trivial reduction with iter domain, that's fine, just make it an
// iter domain.
if ((outer->isTrivialReduction() || inner->isTrivialReduction()) &&
(outer->getIterType() == IterType::Iteration ||
inner->getIterType() == IterType::Iteration)) {
itype = IterType::Iteration;
}

// Merging trivial reduction with broadcasting, that's fine, just make it a
// broadcasting.
if ((outer->isTrivialReduction() || inner->isTrivialReduction()) &&
(outer->isBroadcast() || inner->isBroadcast())) {
itype = IterType::Broadcast;
}
IterType itype = inferIterType(outer, inner);

Val* expanded_extent = nullptr;
if (outer->hasExpandedExtent() || inner->hasExpandedExtent()) {
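
For reference, a self-contained sketch (not from the PR) that mirrors the rule ordering of inferIterType above on a plain enum, so the precedence can be sanity-checked in isolation. Kind::TrivialReduction is shorthand for a reduction domain of extent 1; in the real IR such a domain still has IterType::Reduction and is detected via isTrivialReduction().

#include <cassert>
#include <stdexcept>

// Stand-ins for the IterDomain iteration types used by the rules above.
enum class Kind { Iteration, Broadcast, Reduction, TrivialReduction };

// Applies the same pattern-matching order as inferIterType.
Kind inferMergedKind(Kind outer, Kind inner) {
  if (outer == inner) {
    return outer; // X + X = X
  }
  if (outer == Kind::TrivialReduction) {
    return inner; // trivial reduction + X = X
  }
  if (inner == Kind::TrivialReduction) {
    return outer; // X + trivial reduction = X
  }
  if (outer == Kind::Broadcast) {
    return inner; // broadcasting + X = X
  }
  if (inner == Kind::Broadcast) {
    return outer; // X + broadcasting = X
  }
  throw std::logic_error("iteration types must match");
}

int main() {
  // broadcasting + (non-trivial) reduction = reduction
  assert(inferMergedKind(Kind::Broadcast, Kind::Reduction) == Kind::Reduction);
  // broadcasting + trivial reduction = broadcasting
  assert(inferMergedKind(Kind::Broadcast, Kind::TrivialReduction) == Kind::Broadcast);
  // trivial reduction + iteration = iteration
  assert(inferMergedKind(Kind::TrivialReduction, Kind::Iteration) == Kind::Iteration);
  return 0;
}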
7 changes: 1 addition & 6 deletions torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp
@@ -330,13 +330,8 @@ void multiReductionInliner(
}
}

// Find iter domains that are mapped to a trivial reduction, these should
// never be inlined.
std::unordered_set<IterDomain*> mapped_to_trivial_reduction =
scheduler_utils::getTrivialReductionMap(fusion);

// Inline the schedule
inlineMost(mapped_to_trivial_reduction);
inlineMost();
}

namespace {
86 changes: 13 additions & 73 deletions torch/csrc/jit/codegen/cuda/scheduler/utils.cpp
@@ -21,26 +21,20 @@ namespace scheduler_utils {

// Returns number of "valid" dimensions. e.g. if tv has
// [I1, R2, I3, I4, R3{1}]
// where R3{1} is in dont_merge, resulting domain should be:
// [I1, I3*I4, R2, R3{1}] with return value 3
// resulting domain should be:
// [I1, I3*I4, R2*R3{1}] with return value 3
//
// if tv has
// [R1, I2, R3, I4, R4, R5{1}, R6{1}]
// where R5{1} and R6{1} are in dont_merge, resulting domain should be:
// [I2*I4, R1*R3, R4, R5{1}, R6{1}]
// resulting domain should be:
// [I2*I4, R1*R3, R4*R5{1}*R6{1}]
// with return value 3
size_t merge_3d(
TensorView* tv,
const std::unordered_set<IterDomain*>& dont_merge) {
size_t merge_3d(TensorView* tv) {
bool active_is_reduction = false;
bool first_dim = true;
int prev_i = -1;

for (int i = static_cast<int>(tv->nDims()) - 1; i >= 0; i--) {
if (dont_merge.count(tv->axis(i))) {
continue;
}

if (first_dim) {
active_is_reduction = tv->axis(i)->isReduction();
prev_i = i;
@@ -67,10 +61,6 @@

for (int i = static_cast<int>(tv->nDims()) - 2; i >= 0; i--) {
auto id = tv->axis(i);
if (dont_merge.count(id)) {
continue;
}

if (first_dim) {
active_is_reduction = id->isReduction();
prev_i = i;
@@ -96,10 +86,6 @@
prev_i = -1;

for (int i = static_cast<int>(tv->nDims()) - 3; i >= 0; i--) {
if (dont_merge.count(tv->axis(i))) {
continue;
}

if (first_dim) {
active_is_reduction = tv->axis(i)->isReduction();
prev_i = i;
@@ -114,7 +100,7 @@
if (prev_i == -1) {
// Two dimensional, put merged dimensions first
tv->reorder({{-1, 0}, {-2, 1}});
// [outer, inner, dont_merge...]
// [outer, inner]
if (tv->axis(0)->isReduction()) {
// put reductions as second axis
tv->reorder({{0, 1}, {1, 0}});
@@ -195,13 +181,11 @@ c10::optional<size_t> mergeDims(
return left;
}

size_t mergeReduction(
TensorView* tv,
const std::unordered_set<IterDomain*>& dont_merge) {
size_t mergeReduction(TensorView* tv) {
int prev_i = -1;
size_t num_merged = 0;
for (int i = static_cast<int>(tv->nDims()) - 1; i >= 0; i--) {
if (!tv->axis(i)->isReduction() || dont_merge.count(tv->axis(i))) {
if (!tv->axis(i)->isReduction()) {
continue;
}
if (prev_i == -1) {
@@ -219,16 +203,14 @@
return prev_i == -1 ? 0 : num_merged + 1;
}

size_t mergeNonReduction(
TensorView* tv,
const std::unordered_set<IterDomain*>& dont_merge) {
size_t mergeNonReduction(TensorView* tv) {
int prev_i = -1;
size_t num_merged = 0;
if (tv->nDims() == 0) {
return 0;
}
for (int i = static_cast<int>(tv->nDims()) - 1; i >= 0; i--) {
if (tv->axis(i)->isReduction() || dont_merge.count(tv->axis(i))) {
if (tv->axis(i)->isReduction()) {
continue;
}
if (prev_i == -1) {
@@ -905,63 +887,21 @@ PersistentBufferSizeReturn persistentBufferSize(
return persistent_buffer_size;
}

std::unordered_set<IterDomain*> getTrivialReductionMap(Fusion* fusion) {
auto all_tvs = ir_utils::allTvs(fusion);
std::unordered_set<IterDomain*> mapped_to_trivial_reduction;
for (auto tv : all_tvs) {
// root domain vs domain shouldn't matter as at this point we shouldn't have
// any transformations.
for (auto id : tv->getRootDomain()) {
if (id->isTrivialReduction()) {
mapped_to_trivial_reduction.emplace(id);
}
}
}

if (!mapped_to_trivial_reduction.empty()) {
// Use the loop map as that is the most permissive
auto ca_map = ComputeAtMap(fusion);
// Make a copy we need to check mappings of all
auto trivial_ids = mapped_to_trivial_reduction;
for (auto tv : all_tvs) {
for (auto id : tv->getRootDomain()) {
if (!id->extent()->isOneInt()) {
continue;
}
if (std::any_of(
trivial_ids.begin(),
trivial_ids.end(),
[&ca_map, &id](IterDomain* trivial_id) {
return ca_map.areMapped(
id, trivial_id, IdMappingMode::PERMISSIVE);
})) {
mapped_to_trivial_reduction.emplace(id);
}
}
}
}
return mapped_to_trivial_reduction;
}

std::pair<bool, bool> canonicalDimReduction(
Fusion* fusion,
TensorView* tv,
bool schedule_3D) {
std::unordered_set<IterDomain*> mapped_to_trivial_reduction =
getTrivialReductionMap(fusion);

TORCH_INTERNAL_ASSERT(tv != nullptr);

if (!schedule_3D) {
// We coalesce all reduction axes to the right;
bool has_red_axis = mergeReduction(tv, mapped_to_trivial_reduction) > 0;
bool has_red_axis = mergeReduction(tv) > 0;

bool has_iter_axis = mergeNonReduction(tv, mapped_to_trivial_reduction) > 0;
bool has_iter_axis = mergeNonReduction(tv) > 0;
return {has_iter_axis, has_red_axis};
} else {
TORCH_INTERNAL_ASSERT(
merge_3d(tv, mapped_to_trivial_reduction) == 3,
"Tried 3D merge, but result is not 3D.");
merge_3d(tv) == 3, "Tried 3D merge, but result is not 3D.");
return {true, true};
}
}
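
A short sketch (not from the PR), assuming the scheduler_utils functions above and an existing reduction TensorView, of how the canonicalization path reads now that the dont_merge / trivial-reduction sets are gone; the wrapper name canonicalize is hypothetical.

// Sketch: canonicalize a reduction TensorView without the old
// mapped_to_trivial_reduction plumbing.
void canonicalize(Fusion* fusion, TensorView* red_tv, bool schedule_3D) {
  FusionGuard fg(fusion);

  if (!schedule_3D) {
    // 2D path: reduction axes coalesced to the right, iteration axes to the
    // left; trivial reductions such as R3{1} are merged like any other
    // reduction axis.
    auto axes = scheduler_utils::canonicalDimReduction(fusion, red_tv, false);
    bool has_iter_axis = axes.first; // a non-reduction axis remains
    bool has_red_axis = axes.second; // a reduction axis remains
    (void)has_iter_axis;
    (void)has_red_axis;
  } else {
    // 3D path: merges the domain into three groups and asserts the result is
    // 3D, e.g. [I1, R2, I3, I4, R3{1}] -> [I1, I3*I4, R2*R3{1}].
    scheduler_utils::canonicalDimReduction(fusion, red_tv, true);
  }
}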
12 changes: 4 additions & 8 deletions torch/csrc/jit/codegen/cuda/scheduler/utils.h
@@ -78,16 +78,12 @@ TORCH_CUDA_CU_API inline c10::optional<size_t> mergeDims(
}

// Merge all reduction to the right side and returns total number of
// reduction axes. Don't merge is typically used for trivial reductions.
size_t mergeReduction(
TensorView* tv,
const std::unordered_set<IterDomain*>& dont_merge = {});
// reduction axes.
size_t mergeReduction(TensorView* tv);

// merge all non-reduction axes to the left side and returns total number of
// iteration axes. Don't merge is typically used for trivial reductions.
size_t mergeNonReduction(
TensorView* tv,
const std::unordered_set<IterDomain*>& dont_merge = {});
// iteration axes.
size_t mergeNonReduction(TensorView* tv);

// Propagate the parallelization from the selected dimensions of the reference
// tensor to their corresponding dimensions in all selected tensors in the DAG.