Skip to content

Commit

Permalink
Improve divisible split detection (#1970)
Browse files Browse the repository at this point in the history
  • Loading branch information
csarofeen authored Sep 27, 2022
1 parent 42ccc52 commit 4cbe0db
Show file tree
Hide file tree
Showing 9 changed files with 283 additions and 30 deletions.
1 change: 1 addition & 0 deletions build_variables.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,7 @@ libtorch_cuda_core_sources = [
"torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp",
"torch/csrc/jit/codegen/cuda/lower_allocation.cpp",
"torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp",
"torch/csrc/jit/codegen/cuda/lower_divisible_split.cpp",
"torch/csrc/jit/codegen/cuda/lower_expr_sort.cpp",
"torch/csrc/jit/codegen/cuda/lower_fused_reduction.cpp",
"torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.cpp",
Expand Down
11 changes: 7 additions & 4 deletions torch/csrc/jit/codegen/cuda/compute_at_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,11 @@ void mapMaybeSwizzleOp(
}
}

bool IterDomainGraph::exprsMap(Expr* first, Expr* second, bool forward) {
bool IterDomainGraph::exprsMap(
Expr* first,
Expr* second,
bool forward,
const DisjointSets<IterDomain*>& id_map) {
if (first == nullptr || second == nullptr) {
return false;
}
Expand Down Expand Up @@ -117,8 +121,7 @@ bool IterDomainGraph::exprsMap(Expr* first, Expr* second, bool forward) {
zipped_ids.begin(),
zipped_ids.end(),
[&](std::pair<IterDomain*, IterDomain*> id_pair) {
return !exact_nodes_.strictAreMapped(
id_pair.first, id_pair.second);
return !id_map.strictAreMapped(id_pair.first, id_pair.second);
})) {
return false;
}
Expand Down Expand Up @@ -167,7 +170,7 @@ void IterDomainGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) {
return;
}

if (!exprsMap(first, second, forward)) {
if (!exprsMap(first, second, forward, exact_nodes_)) {
return;
}

Expand Down
55 changes: 30 additions & 25 deletions torch/csrc/jit/codegen/cuda/compute_at_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,41 +88,46 @@ class TORCH_CUDA_CU_API IterDomainGraph {
return view_rfactor_ids_;
}

bool hasSelfMapping() const {
return self_mapping_info_.has_value();
}
// Returns if first and second are expressions through which the provided
// id_map have matching inputs (if forward), or outputs (if not forward).
// Returning true means the expressions are "the same", in the sense that they
// modify matching original extents by the same amount.
static bool exprsMap(
Expr* first,
Expr* second,
bool forward,
const DisjointSets<IterDomain*>& id_map);
}

private:
void build(Fusion* fusion);
bool hasSelfMapping() const {
return self_mapping_info_.has_value();
}

void initializeId(IterDomain* id, bool is_view_rfactor_id, bool is_leaf_id);
private:
void build(Fusion* fusion);

// Returns if first and second are expressions with inputs match through exact
// map (if forward), or outputs match (if not forward).
bool exprsMap(Expr* first, Expr* second, bool forward);
void initializeId(IterDomain* id, bool is_view_rfactor_id, bool is_leaf_id);

// Checks if exprsMap then if forward will map outputs else inputs in exact
// and permissive map.
void mapThroughExpr(Expr* first, Expr* second, bool forward);
// Checks if exprsMap then if forward will map outputs else inputs in exact
// and permissive map.
void mapThroughExpr(Expr* first, Expr* second, bool forward);

DisjointSets<IterDomain*> permissive_nodes_;
DisjointSets<IterDomain*> exact_nodes_;
DisjointSets<IterDomain*> loop_nodes_;
DisjointSets<IterDomain*> permissive_nodes_;
DisjointSets<IterDomain*> exact_nodes_;
DisjointSets<IterDomain*> loop_nodes_;

// Consumers and producers are not symmetric like the other sets
std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>
consumers_;
std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>
producers_;
// Consumers and producers are not symmetric like the other sets
std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>> consumers_;
std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>> producers_;

DisjointSets<IterDomain*> sibling_sets_;
DisjointSets<IterDomain*> sibling_sets_;

VectorOfUniqueEntries<IterDomain*> all_ids_;
VectorOfUniqueEntries<IterDomain*> all_ids_;

std::unordered_set<IterDomain*> view_rfactor_ids_;
std::unordered_set<IterDomain*> view_rfactor_ids_;

c10::optional<std::tuple<TensorView*, IterDomain*, IterDomain*, std::string>>
self_mapping_info_ = c10::nullopt;
c10::optional<std::tuple<TensorView*, IterDomain*, IterDomain*, std::string>>
self_mapping_info_ = c10::nullopt;
};

class TrivialReductionInfo;
Expand Down
4 changes: 4 additions & 0 deletions torch/csrc/jit/codegen/cuda/lower2device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/lower_alias_memory.h>
#include <torch/csrc/jit/codegen/cuda/lower_allocation.h>
#include <torch/csrc/jit/codegen/cuda/lower_divisible_split.h>
#include <torch/csrc/jit/codegen/cuda/lower_double_buffer.h>
#include <torch/csrc/jit/codegen/cuda/lower_expr_sort.h>
#include <torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.h>
Expand Down Expand Up @@ -256,6 +257,9 @@ void GpuLower::lower(Fusion* fusion, DataType index_type) {

compute_at_map_->validateAndPropagatePType();

// Uses compute_at_map, find all splits that are enforced to be divisible
divisible_splits_ = getAllDivisibleSplits(fusion_, compute_at_map_.get());

// Used in parallel dimension map
concretized_broadcast_domains_ =
std::make_shared<const ConcretizedBroadcastDomains>(fusion_);
Expand Down
5 changes: 5 additions & 0 deletions torch/csrc/jit/codegen/cuda/lower2device.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@ class TORCH_CUDA_CU_API GpuLower : public NonCopyable {
return non_divisible_split_info_;
}

const auto& divisbleSplitSet() const {
return divisible_splits_;
}

DoubleBufferInfo& doubleBufferInfo() {
return double_buffer_info_;
}
Expand Down Expand Up @@ -212,6 +216,7 @@ class TORCH_CUDA_CU_API GpuLower : public NonCopyable {
FusedReductionInfo fused_reduction_info_;
SyncMap sync_map_;
kir::KernelPerformanceProfile profile_;
std::unordered_set<Split*> divisible_splits_;

// Track which tensor views are inputs or outputs of a vectorized operation
// and their maximum vectorized access size
Expand Down
121 changes: 121 additions & 0 deletions torch/csrc/jit/codegen/cuda/lower_divisible_split.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@

#include <torch/csrc/jit/codegen/cuda/lower_divisible_split.h>

#include <torch/csrc/jit/codegen/cuda/disjoint_set.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>

#include <unordered_set>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

// Convenience overload: builds a fresh ComputeAtMap for the fusion and
// forwards to the overload that accepts an existing map.
std::unordered_set<Split*> getAllDivisibleSplits(Fusion* fusion) {
  const ComputeAtMap local_ca_map(fusion);
  return getAllDivisibleSplits(fusion, &local_ca_map);
}

// Returns every Split in the fusion that is provably divisible:
//  (1) splits belonging to a view-like rfactor transformation (divisible by
//      definition of view),
//  (2) splits directly producing a vectorized IterDomain (divisibility of
//      these is enforced elsewhere), and
//  (3) any other split whose outer output exact-maps (through ca_map) to one
//      of the above and transforms matching extents by the same factor.
std::unordered_set<Split*> getAllDivisibleSplits(
    Fusion* fusion,
    const ComputeAtMap* ca_map) {
  std::unordered_set<Split*> all_divisible_splits;

  auto all_tvs = ir_utils::allTvs(fusion);
  // Find all tensor views with a view like rfactor. Splits used in view
  // transformations must be divisible by definition.
  for (auto tv : all_tvs) {
    // Not view if there's no rfactor axis
    if (!tv->domain()->hasViewLikeRFactor()) {
      continue;
    }

    // Take the view transformations and add all the splits. Those splits are
    // the only divisible splits. (Only query the rfactor domain after the
    // view check so non-view TVs don't pay for it.)
    auto rfactor_dom = tv->getMaybeRFactorDomain();
    auto view_exprs =
        StmtSort::getExprs(fusion, {rfactor_dom.begin(), rfactor_dom.end()});
    auto split_exprs = ir_utils::filterByType<Split>(view_exprs);
    all_divisible_splits.insert(split_exprs.begin(), split_exprs.end());
  }

  // Vectorized dimensions are enforced to be a result of divisible splits.
  // Gather vectorized splits.
  for (auto tv : all_tvs) {
    auto vec_id_it = std::find_if(
        tv->domain()->domain().begin(),
        tv->domain()->domain().end(),
        [](IterDomain* id) {
          return isParallelTypeVectorize(id->getParallelType());
        });

    if (vec_id_it == tv->domain()->domain().end()) {
      continue;
    }

    // We could have a case technically like:
    // [8, 2] where we do:
    // split(0, 2)
    // merge(1)
    // so it ends up as [4, 4]
    // split(0, 2) must be divisible, but for now we're not going to capture
    // cases like this. Just look for direct split's producing a vectorize
    // dimension.
    auto vec_id = *vec_id_it;
    if (vec_id->definition() != nullptr && vec_id->definition()->isA<Split>()) {
      all_divisible_splits.emplace(vec_id->definition()->as<Split>());
    }
  }

  // If no view-like or vectorization splits were found, nothing can
  // exact-map to a divisible split either.
  if (all_divisible_splits.empty()) {
    return all_divisible_splits;
  }

  // Track the concrete id in the exact map of the outer output of the split
  // expressions. This is how we'll check if there are matching splits. This
  // also gets rid of any splits that already match (for processing).
  std::unordered_map<IterDomain*, Expr*> outer_concrete_id_to_expr;

  for (auto split : all_divisible_splits) {
    outer_concrete_id_to_expr[ca_map->getConcreteMappedID(
        split->outer(), IdMappingMode::EXACT)] = split;
  }

  // Seed with the known-divisible splits so they're not reprocessed below.
  std::unordered_set<Expr*> visited(
      all_divisible_splits.begin(), all_divisible_splits.end());

  // Find splits that match what we already have:
  for (const auto& entry : outer_concrete_id_to_expr) {
    auto concrete_id = entry.first;
    auto original_view_split = entry.second;

    const auto& exact_mapped_ids =
        ca_map->idGraph().exactNodes().getDisjointSetOf(concrete_id).vector();
    for (auto other_id : exact_mapped_ids) {
      if (other_id->definition() == nullptr) {
        continue;
      }

      if (!visited.emplace(other_id->definition()).second) {
        // Already visited
        continue;
      }

      // A split whose outputs (forward==false) exact-map to a known
      // divisible split must itself be divisible.
      if (IterDomainGraph::exprsMap(
              original_view_split,
              other_id->definition(),
              false,
              ca_map->idGraph().exactNodes())) {
        all_divisible_splits.emplace(other_id->definition()->as<Split>());
      }
    }
  }

  return all_divisible_splits;
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch
29 changes: 29 additions & 0 deletions torch/csrc/jit/codegen/cuda/lower_divisible_split.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#pragma once

#include <c10/macros/Export.h>

#include <torch/csrc/jit/codegen/cuda/compute_at_map.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

// Looks through all transformations associated with view, or enforced divisible
// vectorization splits, and gathers all splits that provably don't have a
// remainder, therefore the extents of the associated IterDomains do not require
// a ceilDiv expression.
TORCH_CUDA_CU_API std::unordered_set<Split*> getAllDivisibleSplits(
Fusion* fusion);

// Same as above but will use provided ComputeAtMap instead of building its own.
TORCH_CUDA_CU_API std::unordered_set<Split*> getAllDivisibleSplits(
Fusion* fusion,
const ComputeAtMap* ca_map);

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch
11 changes: 10 additions & 1 deletion torch/csrc/jit/codegen/cuda/non_divisible_split.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,16 @@ void NonDivisibleSplitInfo::handle(Split* split) {
splits_to_validate_.insert(split);
} else {
// Not proven to be a divisible split
splits_to_predicate_[current_tv_].push_back(split);
auto gpu_lower = GpuLower::current();
TORCH_INTERNAL_ASSERT(gpu_lower != nullptr);

// If we know this split must be divisible, it's either validated as
// above, exact matches to a case matching the above, or exact matches
// to a transformation from view which must be divisible.
if (gpu_lower->divisbleSplitSet().find(split) ==
gpu_lower->divisbleSplitSet().end()) {
splits_to_predicate_[current_tv_].push_back(split);
}
}

is_protected = true;
Expand Down
Loading

0 comments on commit 4cbe0db

Please sign in to comment.