-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Cleanup trivial reduction workarounds #2006
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1597,6 +1597,43 @@ std::vector<IterDomain*> IterDomain::clone( | |
return cloned_domains; | ||
} | ||
|
||
IterType inferIterType(IterDomain* i1, IterDomain* i2) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @csarofeen I don't have any concern, but please take a look at here just in case. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Only nit I have is changing the name to understand it only handles resolution through serial, broadcast, and trivial reduced domains. In other words it shouldn't be used for like gather/stride IDs. Nice cleanup though. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think for gather and stride, it applies the first rule |
||
// The itertype inference is a pattern matching of the rules below: | ||
// | ||
// X + X = X | ||
// trivial reduction + X = X | ||
// X + trivial reduction = X | ||
// broadcasting + X = X | ||
// X + broadcasting = X | ||
// fail | ||
// | ||
// The rules are proceeded one by one in order. For each rule, we test if the | ||
// given (outer, inner) matches the pattern. If it does, then we stop | ||
// procceeding and get a result. If we have reached the end without finding | ||
// any matched pattern, then it is a mistake and should be reported. | ||
// | ||
// Note that based on the above rule: | ||
// broadcasting + (non-trivial) reduction = reduction | ||
// broadcasting + trivial reduction = broadcasting | ||
if (i1->getIterType() == i2->getIterType()) { | ||
return i1->getIterType(); | ||
} | ||
if (i1->isTrivialReduction()) { | ||
return i2->getIterType(); | ||
} | ||
if (i2->isTrivialReduction()) { | ||
return i1->getIterType(); | ||
} | ||
if (i1->isBroadcast()) { | ||
return i2->getIterType(); | ||
} | ||
if (i2->isBroadcast()) { | ||
return i1->getIterType(); | ||
} | ||
TORCH_CHECK( | ||
false, "Merging IterDomains requires that their iteration types match."); | ||
} | ||
|
||
// Merging does not propagate the start and stop values of the input | ||
// domains to the merged output domain. The actual range of the | ||
// domains is enforced by predicates. Note that since only root | ||
|
@@ -1606,48 +1643,10 @@ IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) { | |
TORCH_CHECK( | ||
!outer->extent()->isZeroInt() && !inner->extent()->isZeroInt(), | ||
"Merging IterDomains with ending values that are 0 is not supported at this time."); | ||
TORCH_CHECK( | ||
outer->isReduction() == inner->isReduction() || | ||
(!outer->isReduction() && inner->isTrivialReduction()) || | ||
(outer->isTrivialReduction() && !inner->isReduction()), | ||
"Merging IterDomains requires that their iteration types match."); | ||
TORCH_CHECK( | ||
(outer->isGather() && inner->isGather()) || | ||
(!outer->isGather() && !inner->isGather()), | ||
"Merging gather and non-gather domains is not supported."); | ||
|
||
TORCH_CHECK( | ||
!outer->isStride() && !inner->isStride(), | ||
"No support for merging stride domains"); | ||
zasdfgbnm marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Val* merged_id_size = mul(outer->extent(), inner->extent()); | ||
|
||
IterType itype = outer->getIterType(); | ||
|
||
if (outer->isBroadcast() && inner->isBroadcast()) { | ||
itype = IterType::Broadcast; | ||
} | ||
|
||
if ((outer->isBroadcast() || inner->isBroadcast()) && | ||
(outer->getIterType() == IterType::Iteration || | ||
inner->getIterType() == IterType::Iteration)) { | ||
itype = IterType::Iteration; | ||
} | ||
|
||
// Merging trivial reduction with iter domain, that's fine, just make it an | ||
// iter domain. | ||
if ((outer->isTrivialReduction() || inner->isTrivialReduction()) && | ||
(outer->getIterType() == IterType::Iteration || | ||
inner->getIterType() == IterType::Iteration)) { | ||
itype = IterType::Iteration; | ||
} | ||
|
||
// Merging trivial reduction with broadcasting, that's fine, just make it a | ||
// broadcasting. | ||
if ((outer->isTrivialReduction() || inner->isTrivialReduction()) && | ||
(outer->isBroadcast() || inner->isBroadcast())) { | ||
itype = IterType::Broadcast; | ||
} | ||
IterType itype = inferIterType(outer, inner); | ||
|
||
Val* expanded_extent = nullptr; | ||
if (outer->hasExpandedExtent() || inner->hasExpandedExtent()) { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For now this seems fine to remove, though I'm open minded we want ID based avoidance of inlining certain dimensions, but we probably want a better interface for that.