Skip to content

Commit

Permalink
Fix detection of unmappable root domains (#1952)
Browse files Browse the repository at this point in the history
ComputeAtRootDomainMap flags domains that should not be mapped due to
reductions. Previously, checking if a domain potentially causes an
invalid mapping is only done with one domain in each group of domains
that are found to be mappable so far. That's not actually sufficient as
the unmappable domain set is created just once with no root mapping
information. The fix is to check all consumer domains of a producer
tensor. A small other fix is also done to address a different problem
discovered after the first fix.
  • Loading branch information
naoyam authored Sep 2, 2022
1 parent 90a51f2 commit 8eafc54
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 41 deletions.
103 changes: 64 additions & 39 deletions torch/csrc/jit/codegen/cuda/root_domain_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,12 +234,12 @@ class FindInputDomains : BackwardVisitor {
private:
FindInputDomains(TensorView* tv, const IterDomain* id)
: BackwardVisitor(false), tv_(tv) {
input_keys.insert(DomainKey(tv_->domain(), id));
input_keys_.insert(DomainKey(tv_->domain(), id));
}

DomainKeySet find() {
traverseFrom(tv_->fusion(), {tv_});
return input_keys;
return input_keys_;
}

void handle(Expr* expr) override {
Expand All @@ -261,21 +261,21 @@ class FindInputDomains : BackwardVisitor {
.mapConsumerToProducer(out_tv->domain(), in_tv->domain());
for (auto root_dom : out_tv->getRootDomain()) {
DomainKey out_key({out_tv->domain(), root_dom});
if (input_keys.find(out_key) == input_keys.end()) {
if (input_keys_.find(out_key) == input_keys_.end()) {
continue;
}
auto input_id_it = c2p.find(root_dom);
if (input_id_it == c2p.end()) {
continue;
}
DomainKey input_key(in_tv->domain(), input_id_it->second);
input_keys.insert(input_key);
input_keys_.insert(input_key);
}
}

private:
TensorView* tv_ = nullptr;
DomainKeySet input_keys;
DomainKeySet input_keys_;

public:
static DomainKeySet find(TensorView* tv, const IterDomain* id) {
Expand All @@ -297,6 +297,10 @@ void UnmappableReductionDomains::handleReductionOutput(TensorView* out_tv) {
auto use_chains = DependencyCheck::getAllUseChains(out_tv);
for (const auto& chain : use_chains) {
for (const auto& tv : ir_utils::filterByType<TensorView>(chain)) {
// Do not include the tensor itself in its consumers
if (tv == out_tv) {
continue;
}
const auto& root_domain = tv->getRootDomain();
for (const auto& id : root_domain) {
DomainKey consumer_key(tv->domain(), id);
Expand Down Expand Up @@ -339,30 +343,41 @@ void UnmappableReductionDomains::handle(WelfordOp* op) {
}

bool UnmappableReductionDomains::isReductionOutputMapped(
const std::vector<DomainKey>& consumer_domains,
const DomainKeySet& consumer_domains,
const ComputeAtRootDomainMap& root_map) const {
// Check each reduction domain if any of the consumer domains
// conflicts with it
for (const auto& kv : reduction_domains_) {
const DomainKey& reduction_domain = kv.first;
// Domains that must not be mapped with the reduction domain
const DomainKeySet& incompatible_domains = kv.second;
DomainKey consumer_domain_with_reduction;
bool reduction_found = false;
// Input domains to the reduction domain
const auto& input_keys = reduction_domain_inputs_.at(reduction_domain);
for (const DomainKey& consumer_domain : consumer_domains) {
for (const auto& input_key : input_keys) {
if (input_key == consumer_domain) {
consumer_domain_with_reduction = consumer_domain;
reduction_found = true;
break;
}
}
}
if (!reduction_found) {
// Check if any of the consumer domains is an input to the
// reduction
auto it = std::find_if(
consumer_domains.begin(),
consumer_domains.end(),
[&](const auto& consumer_domain) {
return std::find(
input_keys.begin(), input_keys.end(), consumer_domain) !=
input_keys.end();
});
// None of the consumer domains is used for the reduction
// domain. They should be safe with respect to this reduction
// domain
if (it == consumer_domains.end()) {
continue;
}
// Make sure no incompatible domains will be merged with the reduction
// domain.

// A consumer domain that is an input to the reduction domain
const DomainKey& input_to_reduction = *it;

// Check if mapping input_to_reduction with the other domains in
// consumer_domains. If there's a domain that is a consumer of the
// reduction, they must not be mapped together
for (const auto& consumer_domain : consumer_domains) {
if (consumer_domain == consumer_domain_with_reduction) {
if (consumer_domain == input_to_reduction) {
continue;
}
if (std::any_of(
Expand All @@ -382,6 +397,27 @@ bool UnmappableReductionDomains::isReductionOutputMapped(
return false;
}

std::string UnmappableReductionDomains::toString() const {
std::stringstream ss;
ss << "Reduction-to-consumer map\n";
for (const auto& kv : reduction_domains_) {
ss << "\tReduction: " << kv.first.toString() << "\n";
for (const auto& mapped_val : kv.second) {
ss << "\t\tConsumer domain: " << mapped_val.toString() << "\n";
}
}

ss << "Reduction-to-producer map\n";
for (const auto& kv : reduction_domain_inputs_) {
ss << "\tReduction: " << kv.first.toString() << "\n";
for (const auto& mapped_val : kv.second) {
ss << "\t\tProducer domain: " << mapped_val.toString() << "\n";
}
}

return ss.str();
}

void ComputeAtRootDomainMap::build(bool map_through_reduction) {
// Make sure we start from scratch. Throw away previous results.
eq_set_.clear();
Expand Down Expand Up @@ -724,7 +760,7 @@ void ComputeAtRootDomainMapBuilder::setInvalid(
}

bool ComputeAtRootDomainMapBuilder::isInvalid(
const std::vector<DomainKey>& domains) const {
const DomainKeySet& domains) const {
// First, collect all invalid mappings for each of the keys in domains
DomainKeyMap<DomainKeySet> invalid_key_map;
for (const auto& key : domains) {
Expand All @@ -741,8 +777,9 @@ bool ComputeAtRootDomainMapBuilder::isInvalid(

// Next, check if any pair is invalid to map.
const auto num_keys = domains.size();
const std::vector<DomainKey> domains_vec({domains.begin(), domains.end()});
for (const auto i : c10::irange(num_keys)) {
const auto& key_i = domains[i];
const auto& key_i = domains_vec[i];
// If no invalid keys found for key_i, it can be skipped.
const auto invalid_key_map_it = invalid_key_map.find(key_i);
if (invalid_key_map_it == invalid_key_map.end()) {
Expand All @@ -755,7 +792,7 @@ bool ComputeAtRootDomainMapBuilder::isInvalid(
// If any other key in domains is identified mappable with any of
// the keys in this set, the mapping with key_i is invalid.
for (const auto j : c10::irange(i + 1, num_keys)) {
const auto& key_j = domains[j];
const auto& key_j = domains_vec[j];
if (std::any_of(
invalid_keys_for_i.begin(),
invalid_keys_for_i.end(),
Expand Down Expand Up @@ -1070,26 +1107,14 @@ bool ComputeAtRootDomainMapBuilder::safeToMap(const DomainKeySet& domains) {
if (domains.size() <= 1) {
return true;
}
// Filter out equivalent domains
std::vector<DomainKey> unique_domains;
for (const auto& domain : domains) {
if (std::none_of(
unique_domains.begin(),
unique_domains.end(),
[&](const auto& unique_dom) {
return root_map_.canMap(domain, unique_dom);
})) {
unique_domains.push_back(domain);
}
}

// Can't map if reduction output domains would be mapped
if (incompatible_domains_.isReductionOutputMapped(
unique_domains, root_map_) &&
if (incompatible_domains_.isReductionOutputMapped(domains, root_map_) &&
!map_through_reduction_) {
return false;
}
// Make sure mapping these domains won't cause any invalid mapping
if (isInvalid(unique_domains)) {
if (isInvalid(domains)) {
return false;
}
return true;
Expand Down
9 changes: 7 additions & 2 deletions torch/csrc/jit/codegen/cuda/root_domain_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ class DomainKey {
return td() == other.td() && id() == other.id() &&
concreteId() == other.concreteId();
}
bool operator!=(const DomainKey& other) const {
return !(*this == other);
}

std::string toString() const;

Expand Down Expand Up @@ -183,9 +186,11 @@ class TORCH_CUDA_CU_API UnmappableReductionDomains : private IterVisitor {
//! reduction outputs within the corresponding reduction loop is not
//! possible. This routine is used to build root domain mappings.
bool isReductionOutputMapped(
const std::vector<DomainKey>& consumer_domains,
const DomainKeySet& consumer_domains,
const ComputeAtRootDomainMap& root_map) const;

std::string toString() const;

private:
using IterVisitor::handle;
void handle(ReductionOp* op) override;
Expand Down Expand Up @@ -365,7 +370,7 @@ class TORCH_CUDA_CU_API ComputeAtRootDomainMapBuilder
void setInvalid(const DomainKey& key1, const DomainKey& key2);

//! Check if no pair of domains is invalid to map
bool isInvalid(const std::vector<DomainKey>& domains) const;
bool isInvalid(const DomainKeySet& domains) const;

//! Track a pair of producer-consumer domains as potentially mappable. Inserts
//! entries into pending_map_, but does not add anything into the root_map_
Expand Down
32 changes: 32 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3741,6 +3741,38 @@ TEST_F(NVFuserTest, FusionRootMappingTrivialReduction_CUDA) {
testValidate(&fusion, outputs, aten_inputs, {t3, t4}, __LINE__, __FILE__);
}

// Repro of issue #1950
TEST_F(NVFuserTest, FusionRootMappingRepro1950_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
auto tv0 = makeSymbolicTensor(3);
auto tv1 = makeSymbolicTensor(3);
auto tv2 = makeSymbolicTensor(3);

fusion.addInput(tv0);
fusion.addInput(tv1);
fusion.addInput(tv2);

auto tv3 = set(tv0);
auto tv4 = mul(tv1, tv3);
auto tv5 = mul(tv1, tv2);
auto tv6 = mul(tv5, tv3);
auto tv7 = sum(tv6, {2});
auto tv8 = broadcast(tv7, {false, false, true});
auto tv9 = mul(tv3, tv8);

// Issue #1950 was caused by a particular traversal ordering based
// on the output tensor ordering as below
fusion.addOutput(tv9);
fusion.addOutput(tv5);
fusion.addOutput(tv4);

ComputeAtRootDomainMap root_map;
root_map.build();

checkIdMapped(root_map, tv4, tv4->axis(-1), tv9, tv9->axis(-1), false);
}

TEST_F(NVFuserTest, FusionDetectSelfMappedDomains_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
Expand Down

0 comments on commit 8eafc54

Please sign in to comment.