Skip to content

Commit

Permalink
Contiguous indexing for View operations (#1990)
Browse files Browse the repository at this point in the history
  • Loading branch information
csarofeen authored Sep 28, 2022
1 parent a43cb20 commit 967aa77
Show file tree
Hide file tree
Showing 15 changed files with 814 additions and 259 deletions.
617 changes: 497 additions & 120 deletions torch/csrc/jit/codegen/cuda/contiguity.cpp

Large diffs are not rendered by default.

152 changes: 143 additions & 9 deletions torch/csrc/jit/codegen/cuda/contiguity.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,128 @@

#include <c10/macros/Export.h>

#include <torch/csrc/jit/codegen/cuda/compute_at_map.h>
#include <torch/csrc/jit/codegen/cuda/disjoint_set.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/lower_shift.h>
#include <torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

// Traverses the transformations associated with a series of ids and their
// root ids, checking how the iteration domains are ordered through those
// operations in order to pick out which operations are consistently ordered.
// For example, starting from:
//   [i0, i1, i2]
// the sequence
//   ->split(0, 4)->merge(1)->merge(1)->merge(0)
// is consistently ordered from largest to smallest extents, whereas
//   ->split(0, 4)->merge(1)->merge(0, 2)->merge(0)
// is not consistently ordered with the roots.
//
// This property matters for understanding the contiguity of dimensions
// through complex transformations.
class OrderedIdInformation : public OptInDispatch {
 public:
  // There is no meaningful default state; the analysis inputs are required.
  OrderedIdInformation() = delete;

  OrderedIdInformation(
      const std::vector<IterDomain*>& ids,
      const std::vector<IterDomain*>& root_domain,
      std::shared_ptr<const ConcretizedBroadcastDomains> concrete_info);

  // Read-only view of the map from an iter domain to the root ids that were
  // used to generate it.
  const std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>&
  idToRootIds() const {
    return id_to_root_ids_;
  }

  // True when `id` has been recorded in consistently_ordered_ids_.
  bool isConsistentlyOrdered(IterDomain* id) const {
    return consistently_ordered_ids_.count(id) != 0;
  }

  // True when `id` has been recorded in exclusively_consumes_roots_.
  bool exclusivelyConsumesRoots(IterDomain* id) const {
    return exclusively_consumes_roots_.count(id) != 0;
  }

 private:
  // Returns whether the id in active_ids should be in
  // exclusively_consumes_roots_
  bool checkExclusivelyConsumesRoots(IterDomain* id);

  void handle(Split* split) override;

  void handle(Merge* merge) override;

  void handle(Swizzle2D* swizzle) override;

  // For each iter domain, the root ids that were used to generate it.
  std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>
      id_to_root_ids_;

  // All IterDomains whose transforms are correctly ordered for contiguity.
  // For instance, given:
  //
  //   root = [i0, i1, i2]
  //   i3 = merge(i0, i2)
  // i3 would not be a consistently ordered transform, while
  //
  //   root = [i0, i1, i2]
  //   i4, i5 = split(merge(merge(i0, i1), i2), 4)
  // would be consistently ordered transforms, and
  //
  //   root = [i0, i1, i2, i3]
  //   i4 = merge(i1, i2)
  // would also be a consistently ordered transform.
  std::unordered_set<IterDomain*> consistently_ordered_ids_;

  // The active series of IterDomains, updated while the domain is being
  // processed. It helps identify which ids belong in
  // consistently_ordered_ids_. This is intermediate storage only and is not
  // meant to be returned.
  std::vector<IterDomain*> active_ids_;

  // IterDomains in this set exclusively consume all the uses of their roots.
  // For example:
  //   [i0, i1] split(0, f)->merge(1)
  //   [ceilDiv(i0, f), f*i1]
  // neither iter domain exclusively consumes the roots. With one more step:
  //   merge(0) -> [ceilDiv(i0, f)*f*i1]
  // the resulting iter domain does exclusively consume the roots.
  //
  // Likewise:
  //   [i0, i1, i2, i3] merge(1)->merge(1)
  //   ->[i0, i1*i2*i3]
  // both resulting iter domains exclusively consume their roots.
  std::unordered_set<IterDomain*> exclusively_consumes_roots_;

  // Broadcast domains that are concretized cannot be considered contiguously
  // indexable.
  // TODO: This constraint is more conservative than necessary as it only
  // matters if the domain is concretized within the local indexing, not in
  // the entire fusion.
  std::shared_ptr<const ConcretizedBroadcastDomains> concrete_info_;
};

// Based on provided divisible split set, goes through expressions and marks all
// IterDomains that are dependent on a non-divisible split.
class NonDivisibleSplitDependencies : public OptInDispatch {
public:
NonDivisibleSplitDependencies() = delete;

NonDivisibleSplitDependencies(
const std::vector<IterDomain*>& ids,
const std::vector<IterDomain*>& root_domain,
const std::unordered_set<Split*>& divisible_splits);

bool dependsOnNonDivisibleSplit(IterDomain* id) const {
return depends_on_non_divisible_split.find(id) !=
depends_on_non_divisible_split.end();
}

private:
std::unordered_set<IterDomain*> depends_on_non_divisible_split;
};

// A merge is contiguous if:
// Inputs of outer are to the left in the root domain of the inputs of RHS.
// All inputs are contiguous in the root domain:
Expand All @@ -22,8 +137,6 @@ namespace cuda {

class ContigIDs : public OptInDispatch {
public:
ContigIDs() = delete;

//! Check through the history of ids whose inputs map to root_domain with
//! contiguity root_contiguity. Return unordered_set of all merges that are
//! contiguous. Ignore root order is primarily used for predicate generation.
Expand All @@ -42,21 +155,28 @@ class ContigIDs : public OptInDispatch {
//! If ignore_indexability and ignore_halo_constraint are true,
//! ignore the constraint on indexing and halo, respectively. It is
//! the caller that is responsible for its correctness.
//!
//! The function interface with many parameters looks ugly, but it
//! is also important to make ignore_indexability and
//! ignore_halo_constraint explicit to avoid any surprise.
//!
//! Not really sure why but clang-tidy only complains about
//! std::unordered_map if passed as a const reference.
ContigIDs(
const std::vector<IterDomain*>& ids,
const std::vector<IterDomain*>& root_domain,
const std::vector<bool>& root_contiguity,
std::unordered_map<IterDomain*, IterDomain*> concrete_to_ref,
const std::unordered_set<Split*>& divisible_splits,
std::unordered_map<IterDomain*, IterDomain*> p2c_id_map = {},
bool ignore_indexability = false,
bool ignore_halo_constraint = false);
bool ignore_indexability = false);

ContigIDs(
const std::vector<IterDomain*>& ids,
const std::vector<IterDomain*>& root_domain,
const std::vector<bool>& root_contiguity,
std::unordered_map<IterDomain*, IterDomain*> concrete_to_ref,
const std::unordered_set<Split*>& divisible_splits,
std::shared_ptr<const ComputeAtMap> ca_map,
std::shared_ptr<const HaloInfo> halo_info,
std::shared_ptr<const ConcretizedBroadcastDomains> concrete_info,
std::unordered_map<IterDomain*, IterDomain*> p2c_id_map = {},
bool ignore_indexability = false);

const std::unordered_set<IterDomain*>& contigIDs() const {
return contig_ids_;
Expand Down Expand Up @@ -107,13 +227,23 @@ class ContigIDs : public OptInDispatch {
IterDomain* getMappedId(IterDomain* id) const;

private:
void build(const std::vector<IterDomain*>& ids);

//! Root domains to analyze contiguity
const std::vector<IterDomain*>& root_domain_;
//! Contiguity of root_domain_
const std::vector<bool>& root_contiguity_;
//! Mapping of concrete to reference domains. If a concrete domain
//! is not mapped, it is not indexable as there's no mapped index.
const std::unordered_map<IterDomain*, IterDomain*> concrete_to_ref_;
// Divisible split information as we can still consider iter domains
// contiguous through divisible splits.
const std::unordered_set<Split*>& divisible_splits_;

std::shared_ptr<const ComputeAtMap> ca_map_;
std::shared_ptr<const HaloInfo> halo_info_;
std::shared_ptr<const ConcretizedBroadcastDomains> concrete_info_;

//! Producer-to-consumer index map in the case of analyzing replayed
//! producer tensors
const std::unordered_map<IterDomain*, IterDomain*> p2c_id_map_;
Expand All @@ -129,6 +259,10 @@ class ContigIDs : public OptInDispatch {
//! Mapping of root domain to the actual indexed domain, which can
//! be itself or a contig merged domain if found.
std::unordered_map<IterDomain*, IterDomain*> root_to_indexed_id_;

std::unique_ptr<const OrderedIdInformation> consistent_transform_info_;

NonDivisibleSplitDependencies non_divisible_id_info_;
};

} // namespace cuda
Expand Down
8 changes: 7 additions & 1 deletion torch/csrc/jit/codegen/cuda/evaluator_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,13 @@ template <typename IRContext>
void PrecomputedValuesBase<IRContext>::validate() {
  // Checks every entry recorded in binding_log_ against the corresponding
  // entry of values_ and, once all of them have passed the assertion, sets
  // has_valid_values_ to true.
  FUSER_PERF_SCOPE("PrecomputedValuess::Validate");
  // Iterate by const reference so log entries are not copied on each step.
  for (const auto& it : binding_log_) {
    // Single assertion carrying a diagnostic message. A bare assert on the
    // same condition placed before it would always fire first and hide the
    // message, so only the descriptive form is kept.
    TORCH_INTERNAL_ASSERT(
        values_[it.first] == it.second,
        "Precomputed values failed to validate.",
        "\nSomething unexpected changed between the compilation and execution.\n",
        values_[it.first],
        " != ",
        it.second);
  }
  has_valid_values_ = true;
}
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3190,7 +3190,7 @@ class ForceHalfAnnotation : public IterVisitor {
val->getDataType().value() == DataType::BFloat16);
});

annotation.traverseFrom(fusion, fp16_outputs);
annotation.traverseTo(fusion, fp16_outputs);
return annotation.force_fp16_tv_set_;
}

Expand Down
18 changes: 3 additions & 15 deletions torch/csrc/jit/codegen/cuda/index_compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -594,11 +594,7 @@ IndexCompute::IndexCompute(
std::move(extent_map),
std::move(zero_domains),
std::move(zero_merged_in),
ContigIDs(
_td->domain(),
_td->getMaybeRFactorDomain(),
std::vector<bool>(_td->getMaybeRFactorDomain().size(), false),
{}),
ContigIDs({}, {}, {}, {}, {}),
std::move(preferred_paths),
std::move(halo_extent_map)) {}

Expand Down Expand Up @@ -755,7 +751,7 @@ void IndexCompute::run() {
const std::vector<Val*> domain_vals(
td_->domain().begin(), td_->domain().end());

traverseFrom(td_->fusion(), domain_vals, false);
traverseTo(td_->fusion(), domain_vals, false);
}

IterDomain* IndexCompute::maybeGetExactMapConcreteID(IterDomain* id) {
Expand Down Expand Up @@ -851,7 +847,7 @@ class UpdateLeafIndices : public IterVisitor {
const std::vector<Val*> domain_vals(
td_->domain().begin(), td_->domain().end());

traverseFrom(td_->fusion(), domain_vals, false);
traverseTo(td_->fusion(), domain_vals, false);
}

const std::unordered_map<IterDomain*, Val*>& indexMap() const {
Expand Down Expand Up @@ -2954,14 +2950,6 @@ std::vector<RootPredicateInfo> Index::getReferenceRootPredicates(

auto db_axis = gpu_lower->doubleBufferInfo().getDoubleBufferAxis(consumer_tv);

// Indexing is done without considering contig merging. Actual
// predicated domains are determined by considering contiguity.
const ContigIDs contig_finder(
consumer_tv->domain()->domain(),
consumer_tv->getMaybeRFactorDomain(),
std::vector<bool>(consumer_tv->getMaybeRFactorDomain().size(), false),
{});

// Generate start and stop indexing from idgraph.
//
// Both start and stop positions may need to be predicated. Indexing
Expand Down
Loading

0 comments on commit 967aa77

Please sign in to comment.