diff --git a/build_variables.bzl b/build_variables.bzl index defa57daed6445..f70d4280825af0 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -649,7 +649,7 @@ libtorch_cuda_core_sources = [ "torch/csrc/autograd/functions/comm.cpp", "torch/csrc/jit/codegen/cuda/arith.cpp", "torch/csrc/jit/codegen/cuda/compute_at.cpp", - "torch/csrc/jit/codegen/cuda/inline_propagator.cpp", + "torch/csrc/jit/codegen/cuda/inlining.cpp", "torch/csrc/jit/codegen/cuda/compute_at_map.cpp", "torch/csrc/jit/codegen/cuda/codegen.cpp", "torch/csrc/jit/codegen/cuda/contiguity.cpp", diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp index ae6231614b7ff6..d8f950848f8fcb 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp @@ -213,20 +213,21 @@ void ComputeAt::runAt( auto selected = getPropagationSubgraph(producer, consumer); ComputeAtSelector selector(selected); - InlinePropagator inline_propagator( - consumer, consumer_position, mode, selector.selected()); - MaxRootDomainInfoSpanningTree path(consumer, consumer_position, &selector); if (mode == ComputeAtMode::MostInlined) { MostInlinedTransformPropagator propagator; path.traverse(&propagator); + inlineMost(selected); } else { TransformPropagator propagator(consumer, consumer_position); path.traverse(&propagator); + inlineSelectedAt( + selected, + consumer, + consumer_position, + mode == ComputeAtMode::BestEffort); } - - path.traverse(&inline_propagator); } void ComputeAt::runWith( @@ -253,19 +254,21 @@ void ComputeAt::runWith( auto selected = getPropagationSubgraph(producer, consumer); ComputeAtSelector selector(selected); - InlinePropagator inline_propagator( - producer, producer_position, mode, selector.selected()); - MaxRootDomainInfoSpanningTree path(producer, producer_position, &selector); if (mode == ComputeAtMode::MostInlined) { MostInlinedTransformPropagator propagator; path.traverse(&propagator); + inlineMost(selected); } 
else { TransformPropagator propagator(producer, producer_position); path.traverse(&propagator); + inlineSelectedAt( + selected, + producer, + producer_position, + mode == ComputeAtMode::BestEffort); } - path.traverse(&inline_propagator); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/compute_at.h b/torch/csrc/jit/codegen/cuda/compute_at.h index 98100334d72b6e..d3d3fdb299dd69 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.h +++ b/torch/csrc/jit/codegen/cuda/compute_at.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/grouped_reduction.cpp b/torch/csrc/jit/codegen/cuda/grouped_reduction.cpp index 5931eb3427aa97..d907a0665e9f64 100644 --- a/torch/csrc/jit/codegen/cuda/grouped_reduction.cpp +++ b/torch/csrc/jit/codegen/cuda/grouped_reduction.cpp @@ -38,7 +38,7 @@ bool hasMatchingTransformations(TensorView* ref, TensorView* other) { } // Validate grouping of reductions and return a new max producer position -unsigned int validateReductionGrouping( +void validateReductionGrouping( const std::vector& inputs, const std::vector& outputs) { TORCH_INTERNAL_ASSERT(inputs.size() == outputs.size()); @@ -57,7 +57,6 @@ unsigned int validateReductionGrouping( const auto num_root_dims = ref_domain.size(); const auto num_dims = ref_tv->nDims(); const auto ref_ca_pos = ref_tv->getComputeAtPosition(); - auto max_producer_pos = ref_tv->getMaxProducerPosition(); for (const auto i : c10::irange(inputs.size())) { auto output_tv = outputs.at(i)->as(); const auto& output_domain = output_tv->getRootDomain(); @@ -136,9 +135,6 @@ unsigned int validateReductionGrouping( ref_tv->toString(), ". 
Mismatched tensor: ", output_tv->toString()); - - max_producer_pos = - std::max(max_producer_pos, output_tv->getMaxProducerPosition()); } // Must not have any data dependency from outputs to inputs @@ -152,8 +148,6 @@ unsigned int validateReductionGrouping( } TORCH_INTERNAL_ASSERT(all_dep_vals.empty(), ss.str()); } - - return max_producer_pos; } } // namespace @@ -194,14 +188,14 @@ void groupReductions(const std::vector& reduction_outputs) { inputs.at(i) = rop->in(); } - auto max_producer_pos = validateReductionGrouping(inputs, outputs); - - for (auto output : ir_utils::filterByType(outputs)) { - output->setMaxProducer(max_producer_pos); - } + validateReductionGrouping(inputs, outputs); IrBuilder::create( container, op_types, init_vals, outputs, inputs); + + for (auto output : ir_utils::filterByType(outputs)) { + output->updateMaxProducerPosition(); + } } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp deleted file mode 100644 index a5edae083a32a4..00000000000000 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp +++ /dev/null @@ -1,385 +0,0 @@ -#include -#include -#include -#include - -#include - -namespace torch { -namespace jit { -namespace fuser { -namespace cuda { - -MaxPosCalculator::MaxPosCalculator( - ComputeAtMode mode, - std::unordered_set uninlinable_ids) - : mode_(mode), uninlinable_ids_(std::move(uninlinable_ids)) { - buildUnmappableDims(); -} - -void MaxPosCalculator::buildUnmappableDims() { - ComputeAtRootDomainMap root_map; - root_map.build(); - - auto all_tvs = ir_utils::allTvs(FusionGuard::getCurFusion()); - for (auto tv : all_tvs) { - auto consumers = ir_utils::consumerTvsOf(tv); - for (auto consumer : consumers) { - // Grab dimensions in producer and consumer that are mappable to eachother - // based on the computeAtRootDomainMap. 
This will tell us which dimensions - // can be inlined based on avoiding trying to inline non-trivial - // reduction structures. - auto mappable_roots = - root_map.getMappableDims(tv->domain(), consumer->domain()); - for (auto tv_root_id : tv->getMaybeRFactorDomain()) { - if (mappable_roots.find(tv_root_id) == mappable_roots.end() && - !tv_root_id->isTrivialReduction()) { - unmappable_dims_.emplace(tv_root_id); - } - } - } - } -} - -bool MaxPosCalculator::isAllowedID( - IterDomain* id, - TensorView* tv, - bool allow_reduction, - bool allow_vectorize, - bool allow_unmappable) const { - bool allowed = true; - - if (!allow_reduction) { - allowed = allowed && !id->isReduction(); - } - - if (uninlinable_ids_.count(id)) { - return false; - } - - if (!allow_vectorize) { - // Avoid inlining if marked as Vectorize or Group. In the case of - // BestEffort and MostInlined modes, avoid Unroll as well. - bool is_vectorize = isParallelTypeVectorize(id->getParallelType()) || - id->getParallelType() == ParallelType::Group || - ((mode_ == ComputeAtMode::BestEffort || - mode_ == ComputeAtMode::MostInlined) && - id->getParallelType() == ParallelType::Unroll); - allowed = allowed && !is_vectorize; - } - - if (!allow_unmappable) { - auto root_dom = tv->getMaybeRFactorDomain(); - std::unordered_set root_dom_set(root_dom.begin(), root_dom.end()); - auto all_vals = DependencyCheck::getAllValsBetween(root_dom_set, {id}); - bool is_unmappable = false; - for (auto val : all_vals) { - auto id = val->as(); - if (root_dom_set.count(val) > 0 && unmappable_dims_.count(id) > 0) { - is_unmappable = true; - break; - } - } - allowed = allowed && !is_unmappable; - } - - return allowed; -} - -size_t MaxPosCalculator::getMaxPosSelf( - TensorView* tv, - bool allow_reduction, - bool allow_vectorize, - bool allow_unmappable) const { - auto dom = tv->domain()->domain(); - auto iter = std::find_if(dom.begin(), dom.end(), [=](IterDomain* id) { - return !isAllowedID( - id, tv, allow_reduction, allow_vectorize, 
allow_unmappable); - }); - return std::distance(dom.begin(), iter); -} - -// Return the max position in producer that can be inlined to consumer -// Cannot inline: -// Vectorized dimensions in consumer -// Unrolled dimensions in consumer -size_t MaxPosCalculator::getMaxProducerPosFromConsumer( - TensorView* producer, - TensorView* consumer) const { - auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer); - auto replay_CasP = - BestEffortReplay::replayCasP(consumer, producer, -1, pairwise_root_map); - auto p2c_replay_map = replay_CasP.getReplay(); - - for (size_t producer_pos = 0; producer_pos < producer->nDims(); - producer_pos++) { - // If the producer position is mismatching with the consumer, then we can - // not inline into this position, otherwise the max producer position of - // the consumer will become invalid and expression sort will fail. - if (TransformReplay::getMatchedLeafPosWithoutReplayCasP( - consumer, producer, producer_pos + 1) < 0) { - return producer_pos; - } - auto map_it = p2c_replay_map.find(producer->axis(producer_pos)); - if (map_it != p2c_replay_map.end()) { - auto c_id = map_it->second; - if (!isAllowedID(c_id, consumer, true, false, true)) { - return producer_pos; - } - } - } - return producer->nDims(); -} - -size_t InlinePropagator::getMaxPosAll(TensorView* tv, bool check_siblings) { - auto max_pos = max_pos_calc.getMaxPosSelf(tv, false, false, false); - for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) { - max_pos = std::min( - max_pos, max_pos_calc.getMaxProducerPosFromConsumer(tv, consumer_tv)); - } - if (check_siblings) { - for (auto sibling_tv : ir_utils::siblingTvsOf(tv)) { - max_pos = std::min(max_pos, getMaxPosAll(sibling_tv, false)); - } - } - return max_pos; -} - -void InlinePropagator::setCAPos(TensorView* tv) { - bool debug = isDebugDumpEnabled(DebugDumpOption::InlinePropagator); - size_t pos = mapped_reference_pos_.at(tv); - if (debug) { - std::cout << " Setting CA pos of " << tv << ":" << std::endl; - 
std::cout << " mapped position: " << pos << std::endl; - } - if ((selected_.empty() || selected_.count(tv)) && !tv->isFusionInput()) { - auto max_pos = getMaxPosAll(tv); - if (debug) { - std::cout << " max inlinable position: " << max_pos << std::endl; - } - if (mode_ == ComputeAtMode::Standard) { - TORCH_INTERNAL_ASSERT( - pos <= max_pos, - "Invalid compute at position detected in InlinePropagator when trying to set the CA position of: ", - tv, - " to ", - pos, - ", max position that's allowed is ", - max_pos); - } else if (mode_ == ComputeAtMode::BestEffort) { - pos = std::min(pos, max_pos); - } else { - pos = max_pos; - } - // hoist inner most broadcast - while (pos > 0 && tv->axis(pos - 1)->isBroadcast()) { - pos--; - } - auto current_ca_pos = tv->getComputeAtPosition(); - if (debug) { - std::cout << " current CA position: " << current_ca_pos << std::endl; - } - if (pos > current_ca_pos) { - if (debug) { - std::cout << " new CA position: " << pos << std::endl; - } - tv->setComputeAt(pos); - for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) { - needs_update_max_producer_.insert(consumer_tv); - } - } else if (debug) { - std::cout << " CA position not changed" << std::endl; - } - } else if (debug) { - std::cout << " tensor not selected, skip" << std::endl; - } -} - -InlinePropagator::InlinePropagator( - TensorView* reference, - int64_t reference_pos, - ComputeAtMode mode, - std::unordered_set selected, - std::unordered_set uninlinable_ids) - : max_pos_calc(mode, std::move(uninlinable_ids)), - selected_(std::move(selected)), - reference_(reference), - mode_(mode) { - if (reference_pos < 0) { - reference_pos += int64_t(reference->nDims()) + 1; - } - TORCH_INTERNAL_ASSERT( - reference_pos >= 0 && reference_pos <= reference->nDims(), - "Invalid computeAt axis, received ", - reference_pos, - " but should be > -", - reference->nDims(), - " and <= ", - reference->nDims(), - "."); - reference_pos_ = reference_pos; -} - -void InlinePropagator::setUp() { - bool debug = 
isDebugDumpEnabled(DebugDumpOption::InlinePropagator); - mapped_reference_pos_[reference_] = reference_pos_; - if (debug) { - std::cout << "InlinePropagator::setUp" << std::endl; - std::cout << " reference: " << reference_ << " @ " << reference_pos_ - << std::endl; - } - setCAPos(reference_); -} - -namespace { - -// Try to find the aligned position on consumer's domain corresponding to the -// compute at position of producer domain. Used in InlinePropagator pass only. -// No checking on actual producer-consumer relationship. -unsigned int getConsumerPosAlignedToProducerCA( - TensorView* consumer, - TensorView* producer) { - // Locate consumer's position that aligns with - // the producer's new compute at axis. We need broadcast axes forwarded so we - // need to replay PasC as CasP will not forward braodcast dims. For example - // if we have: - // T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) = broadcast( T1[ iS1{3} ] ca_pos( 1 ) - // produce_pos( 1) ) CasP will have the mapping iS1{3} -> iS2{3} and PasC will - // have the mapping iS22{( 3 * 1 )} <- iS1{3} We need the latter. Refer to - // NVFuserTest.FusionComplexBCast1_CUDA - - auto disjoint_sets = - BestEffortReplay::replayPasC( - producer, consumer, -1, PairwiseRootDomainMap(producer, consumer)) - .getDisjointSets(); - - // Find the innermost position of consumer that has - // been mapped within the producer ca axis. 
- unsigned int consumer_pos = consumer->nDims(); - while (consumer_pos > 0) { - auto consumer_id = consumer->axis((int)consumer_pos - 1); - auto p_dom = producer->domain()->domain(); - if (std::any_of( - p_dom.begin(), - p_dom.begin() + producer->getComputeAtPosition(), - [&consumer_id, &disjoint_sets](IterDomain* p_id) { - return disjoint_sets.permissiveAreMapped(consumer_id, p_id); - })) { - break; - } - consumer_pos--; - } - - return consumer_pos; -} - -} // namespace - -void InlinePropagator::tearDown() { - for (auto consumer : needs_update_max_producer_) { - unsigned int consumer_pos = 0; - for (auto producer : ir_utils::producerTvsOf(consumer)) { - consumer_pos = std::max( - consumer_pos, getConsumerPosAlignedToProducerCA(consumer, producer)); - } - consumer->setMaxProducer(consumer_pos); - } -} - -void InlinePropagator::propagateC2P(TensorView* from, TensorView* to) { - bool debug = isDebugDumpEnabled(DebugDumpOption::InlinePropagator); - if (debug) { - std::cout << "InlinePropagator::propagateC2P" << std::endl; - std::cout << " from: " << from << std::endl; - std::cout << " to: " << to << std::endl; - } - // Step 1: find mapped_reference_pos_[to] - int from_pos = mapped_reference_pos_.at(from); - auto to_pos = - TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, from_pos); - if (mode_ == ComputeAtMode::Standard) { - TORCH_CHECK( - to_pos >= 0, - "Unable to propagate CA position from consumer ", - from, - " at ", - from_pos, - " to producer ", - to, - " because this would require replay."); - } else { - // For MostInlined and BestEffort inline propagation, we allow the DAG to - // be not replayed fully consistently. For such case, we just don't inline - // into the mismatched dimension. 
- while (to_pos < 0) { - from_pos--; - to_pos = TransformReplay::getMatchedLeafPosWithoutReplayPasC( - to, from, from_pos); - } - } - mapped_reference_pos_[to] = to_pos; - // Step 2: set CA position of `to` - setCAPos(to); -} - -void InlinePropagator::propagateP2C(TensorView* from, TensorView* to) { - bool debug = isDebugDumpEnabled(DebugDumpOption::InlinePropagator); - if (debug) { - std::cout << "InlinePropagator::propagateP2C" << std::endl; - std::cout << " from: " << from << std::endl; - std::cout << " to: " << to << std::endl; - } - // Step 1: find mapped_reference_pos_[to] - int from_pos = mapped_reference_pos_.at(from); - auto to_pos = - TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, from_pos); - if (mode_ == ComputeAtMode::Standard) { - TORCH_CHECK( - to_pos >= 0, - "Unable to propagate CA position from producer ", - from, - " at ", - from_pos, - " to consumer ", - to, - " because this would require replay."); - } else { - // For MostInlined and BestEffort inline propagation, we allow the DAG to - // be not replayed fully consistently. For such case, we just don't inline - // into the mismatched dimension. 
- while (to_pos < 0) { - from_pos--; - to_pos = TransformReplay::getMatchedLeafPosWithoutReplayCasP( - to, from, from_pos); - } - } - mapped_reference_pos_[to] = to_pos; - // Step 2: set CA position of `to` - setCAPos(to); -} - -void InlinePropagator::propagateSibling(TensorView* from, TensorView* to) { - bool debug = isDebugDumpEnabled(DebugDumpOption::InlinePropagator); - if (debug) { - std::cout << "InlinePropagator::propagateSibling" << std::endl; - std::cout << " from: " << from << std::endl; - std::cout << " to: " << to << std::endl; - } - // Step 1: find mapped_reference_pos_[to] - auto from_pos = mapped_reference_pos_.at(from); - TORCH_CHECK( - TransformReplay::fullSelfMatching(to, from), - "Unable to propagate CA position from ", - from, - " to sibling ", - to, - " because this would require replay."); - mapped_reference_pos_[to] = from_pos; - // Step 2: set CA position of `to` - setCAPos(to); -} - -} // namespace cuda -} // namespace fuser -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.h b/torch/csrc/jit/codegen/cuda/inline_propagator.h deleted file mode 100644 index d1bdeebd06d63e..00000000000000 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.h +++ /dev/null @@ -1,118 +0,0 @@ -#pragma once - -#include -#include -#include - -#include - -namespace torch { -namespace jit { -namespace fuser { -namespace cuda { - -class TORCH_CUDA_CU_API MaxPosCalculator { - ComputeAtMode mode_ = ComputeAtMode::Standard; - - // Root domains in producer that's unmappable to any of its consumers - std::unordered_set unmappable_dims_; - - // User set IterDomains to not inline, used in schedulers to avoid inlining - // trivial reductions - std::unordered_set uninlinable_ids_; - - // Iterate through all TVs and collect the dimensions of each TV that don't - // map to all its consumer TVs. - void buildUnmappableDims(); - - // Utility function to return if an id of tv is a valid iter domain to inline - // within. 
This is used in getMaxPos{PasC,CasP}. Different variations of the - // bool values are used if checking max position of PasC, CasP, or checking - // for a max "self" position. - bool isAllowedID( - IterDomain* id, - TensorView* tv, - bool allow_reduction, - bool allow_vectorize, - bool allow_unmappable) const; - - public: - // Returns the position at which tv can be inlined within. - size_t getMaxPosSelf( - TensorView* tv, - bool allow_reduction, - bool allow_vectorize, - bool allow_unmappable) const; - - // Returns the maximum position producer can be inlined based on consumer - // given the set ComputeAtMode - size_t getMaxProducerPosFromConsumer( - TensorView* producer, - TensorView* consumer) const; - - MaxPosCalculator( - ComputeAtMode mode, - std::unordered_set uninlinable_ids = {}); -}; - -// Propagate inline position to the `selected` tensors in the DAG. If `selected` -// is not specified or empty, then propagate to the entire DAG. -class TORCH_CUDA_CU_API InlinePropagator - : public MaxInfoSpanningTree::Propagator { - // Checks producers and consumers to see what the maximum position in tv is - // that can be shared across both directions. - size_t getMaxPosAll(TensorView* tv, bool check_siblings = true); - - // We use mapped_reference_pos_ to keep track of the outer axes information of - // the reference tensor. That is, mapped_reference_pos_[tv] answers the - // question "What outer axes in tv are shared with the specified reference - // tensor's outer axes?". However, when we actually set the CA position of tv, - // we might not want to set it as mapped_reference_pos_[tv] because because we - // don't want to inline certain things (such as vectorized dimensions, inner - // most broadcasting, etc.). - std::unordered_map mapped_reference_pos_; - - // Actually set the computeAt position. This does not necessarily equal to - // mapped_reference_pos_[tv] because we don't want to inline certain things. 
- void setCAPos(TensorView* tv); - - const MaxPosCalculator max_pos_calc; - std::unordered_set selected_; - std::unordered_set needs_update_max_producer_; - TensorView* reference_; - size_t reference_pos_; - ComputeAtMode mode_ = ComputeAtMode::Standard; - - public: - InlinePropagator( - TensorView* reference, - int64_t reference_pos, - ComputeAtMode mode = ComputeAtMode::Standard, - std::unordered_set selected = {}, - std::unordered_set uninlinable_ids = {}); - - InlinePropagator( - TensorView* reference, - int64_t reference_pos, - std::unordered_set selected) - : InlinePropagator( - reference, - reference_pos, - ComputeAtMode::Standard, - selected) {} - - ~InlinePropagator() = default; - - // Actually propagate the transformations for the inlining pass. Uses the - // functions above to figure out what position to do the propagation at. - virtual void setUp() override; - virtual void propagateC2P(TensorView* from, TensorView* to) override; - virtual void propagateP2C(TensorView* from, TensorView* to) override; - virtual void propagateSibling(TensorView* from, TensorView* to) override; - virtual void tearDown() override; -}; - -} // namespace cuda -} // namespace fuser -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/inlining.cpp b/torch/csrc/jit/codegen/cuda/inlining.cpp new file mode 100644 index 00000000000000..da6d229c68f8b5 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/inlining.cpp @@ -0,0 +1,306 @@ +#include +#include +#include +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +MaxPosCalculator::MaxPosCalculator( + const std::unordered_set& uninlinable_ids) + : uninlinable_ids_(uninlinable_ids) { + buildUnmappableDims(); +} + +void MaxPosCalculator::buildUnmappableDims() { + ComputeAtRootDomainMap root_map; + root_map.build(); + auto all_tvs = ir_utils::allTvs(FusionGuard::getCurFusion()); + for (auto tv : all_tvs) { + auto consumers = ir_utils::consumerTvsOf(tv); + for (auto 
consumer : consumers) { + // Grab producer and consumer dimensions that are mappable to each other + // based on the computeAtRootDomainMap. This will tell us which dimensions + // can be inlined based on avoiding trying to inline non-trivial + // reduction structures. + auto mappable_roots = + root_map.getMappableDims(tv->domain(), consumer->domain()); + for (auto tv_root_id : tv->getMaybeRFactorDomain()) { + if (mappable_roots.find(tv_root_id) == mappable_roots.end() && + !tv_root_id->isTrivialReduction()) { + unmappable_dims_.emplace(tv_root_id); + } + } + } + } +} + +bool MaxPosCalculator::isAllowedID( + IterDomain* id, + TensorView* tv, + bool best_effort, + bool allow_reduction, + bool allow_vectorize, + bool allow_unmappable) const { + bool allowed = true; + + if (!allow_reduction) { + allowed = allowed && !id->isReduction(); + } + + if (uninlinable_ids_.count(id)) { + return false; + } + + if (!allow_vectorize) { + // Avoid inlining if marked as Vectorize or Group. In the case of + // BestEffort and MostInlined modes, avoid Unroll as well. 
+ bool is_vectorize = isParallelTypeVectorize(id->getParallelType()) || + id->getParallelType() == ParallelType::Group || + (best_effort && id->getParallelType() == ParallelType::Unroll); + allowed = allowed && !is_vectorize; + } + + if (!allow_unmappable) { + auto root_dom = tv->getMaybeRFactorDomain(); + std::unordered_set root_dom_set(root_dom.begin(), root_dom.end()); + auto all_vals = DependencyCheck::getAllValsBetween(root_dom_set, {id}); + bool is_unmappable = false; + for (auto val : all_vals) { + auto id = val->as(); + if (root_dom_set.count(val) > 0 && unmappable_dims_.count(id) > 0) { + is_unmappable = true; + break; + } + } + allowed = allowed && !is_unmappable; + } + + return allowed; +} + +size_t MaxPosCalculator::getMaxPosSelf( + TensorView* tv, + bool best_effort, + bool allow_reduction, + bool allow_vectorize, + bool allow_unmappable) const { + auto dom = tv->domain()->domain(); + auto iter = std::find_if(dom.begin(), dom.end(), [=](IterDomain* id) { + return !isAllowedID( + id, + tv, + best_effort, + allow_reduction, + allow_vectorize, + allow_unmappable); + }); + return std::distance(dom.begin(), iter); +} + +// Return the max position in producer that can be inlined to consumer +// Cannot inline: +// Vectorized dimensions in consumer +// Unrolled dimensions in consumer +size_t MaxPosCalculator::getMaxProducerPosFromConsumer( + TensorView* producer, + TensorView* consumer, + bool best_effort) const { + auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer); + auto replay_CasP = + BestEffortReplay::replayCasP(consumer, producer, -1, pairwise_root_map); + auto p2c_replay_map = replay_CasP.getReplay(); + + for (size_t producer_pos = 0; producer_pos < producer->nDims(); + producer_pos++) { + // If the producer position is mismatching with the consumer, then we can + // not inline into this position, otherwise the max producer position of + // the consumer will become invalid and expression sort will fail. 
+ if (TransformReplay::getMatchedLeafPosWithoutReplayCasP( + consumer, producer, producer_pos + 1) < 0) { + return producer_pos; + } + auto map_it = p2c_replay_map.find(producer->axis(producer_pos)); + if (map_it != p2c_replay_map.end()) { + auto c_id = map_it->second; + if (!isAllowedID(c_id, consumer, best_effort, true, false, true)) { + return producer_pos; + } + } + } + return producer->nDims(); +} + +size_t MaxPosCalculator::getMaxPosAll( + TensorView* tv, + bool best_effort, + bool check_siblings) { + auto max_pos = getMaxPosSelf(tv, best_effort, false, false, false); + for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) { + max_pos = std::min( + max_pos, getMaxProducerPosFromConsumer(tv, consumer_tv, best_effort)); + } + if (check_siblings) { + for (auto sibling_tv : ir_utils::siblingTvsOf(tv)) { + max_pos = std::min( + max_pos, getMaxPosAll(sibling_tv, best_effort, false)); + } + } + return max_pos; +} + +void inlineMost(const std::unordered_set& uninlinable_ids) { + inlineMost(ir_utils::allTvs(FusionGuard::getCurFusion()), uninlinable_ids); +} + +void inlineMost( + const std::vector& tvs, + const std::unordered_set& uninlinable_ids) { + if (tvs.empty()) { + return; + } + MaxPosCalculator calc(uninlinable_ids); + for (auto tv : tvs) { + tv->inlineAt(-1, true, &calc); + } +} + +void inlineMost( + const std::unordered_set& tvs, + const std::unordered_set& uninlinable_ids) { + if (tvs.empty()) { + return; + } + MaxPosCalculator calc(uninlinable_ids); + for (auto tv : tvs) { + tv->inlineAt(-1, true, &calc); + } +} + +namespace { + +// Find the positions of `selected` tensors that are mapped to the given +// position in the reference tensor. 
+class FindMappedPositions : public MaxInfoSpanningTree::Propagator { + std::unordered_map& output_; + + public: + FindMappedPositions( + std::unordered_map& output, + TensorView* reference, + int64_t reference_pos); + + ~FindMappedPositions() = default; + + virtual void propagateC2P(TensorView* from, TensorView* to) override; + virtual void propagateP2C(TensorView* from, TensorView* to) override; + virtual void propagateSibling(TensorView* from, TensorView* to) override; +}; + +FindMappedPositions::FindMappedPositions( + std::unordered_map& output, + TensorView* reference, + int64_t reference_pos) + : output_(output) { + if (reference_pos < 0) { + reference_pos += int64_t(reference->nDims()) + 1; + } + TORCH_CHECK( + reference_pos >= 0 && reference_pos <= reference->nDims(), + "Invalid axis received ", + reference_pos, + " but should be > -", + reference->nDims(), + " and <= ", + reference->nDims(), + "."); + output_[reference] = reference_pos; +} + +void FindMappedPositions::propagateC2P(TensorView* from, TensorView* to) { + int from_pos = output_.at(from); + auto to_pos = + TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, from_pos); + // If there is no matching position found, we compute the highest matched + // position as the closest approximation + while (to_pos < 0) { + from_pos--; + to_pos = + TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, from_pos); + } + output_[to] = to_pos; +} + +void FindMappedPositions::propagateP2C(TensorView* from, TensorView* to) { + int from_pos = output_.at(from); + auto to_pos = + TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, from_pos); + // If there is no matching position found, we compute the highest matched + // position as the closest approximation + while (to_pos < 0) { + from_pos--; + to_pos = + TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, from_pos); + } + output_[to] = to_pos; +} + +void FindMappedPositions::propagateSibling(TensorView* from, TensorView* to) { 
+ auto from_pos = output_.at(from); + TORCH_CHECK( + TransformReplay::fullSelfMatching(to, from), + "Transformations in siblings ", + from, + " and ", + to, + " does not match with each other."); + output_[to] = from_pos; +} + +std::unordered_map getPositionsMappedTo( + TensorView* reference_tv, + int64_t reference_pos) { + std::unordered_map mapped_positions; + MaxRootDomainInfoSpanningTree tree(reference_tv, reference_pos); + FindMappedPositions propagator(mapped_positions, reference_tv, reference_pos); + tree.traverse(&propagator); + return mapped_positions; +} + +} // namespace + +void inlineAllAt( + TensorView* reference_tv, + int64_t reference_pos, + bool best_effort, + const std::unordered_set& uninlinable_ids) { + auto mapped_positions = getPositionsMappedTo(reference_tv, reference_pos); + MaxPosCalculator calc(uninlinable_ids); + for (auto pair : mapped_positions) { + pair.first->inlineAt(pair.second, best_effort, &calc); + } +} + +void inlineSelectedAt( + const std::unordered_set& selected, + TensorView* reference_tv, + int64_t reference_pos, + bool best_effort, + const std::unordered_set& uninlinable_ids) { + auto mapped_positions = getPositionsMappedTo(reference_tv, reference_pos); + MaxPosCalculator calc(uninlinable_ids); + for (auto pair : mapped_positions) { + if (selected.count(pair.first) > 0) { + pair.first->inlineAt(pair.second, best_effort, &calc); + } + } +} + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/inlining.h b/torch/csrc/jit/codegen/cuda/inlining.h new file mode 100644 index 00000000000000..3b15eb23f98777 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/inlining.h @@ -0,0 +1,100 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +class MaxPosCalculator { + // Root domains in producer that's unmappable to any of its consumers + std::unordered_set unmappable_dims_; 
+ + // User set IterDomains to not inline, used in schedulers to avoid inlining + // trivial reductions + std::unordered_set uninlinable_ids_; + + // Iterate through all TVs and collect the dimensions of each TV that don't + // map to all its consumer TVs. + void buildUnmappableDims(); + + // Utility function to return if an id of tv is a valid iter domain to inline + // within. This is used in getMaxPos{PasC,CasP}. Different variations of the + // bool values are used if checking max position of PasC, CasP, or checking + // for a max "self" position. + bool isAllowedID( + IterDomain* id, + TensorView* tv, + bool best_effort, + bool allow_reduction, + bool allow_vectorize, + bool allow_unmappable) const; + + public: + // Returns the position at which tv can be inlined within. + size_t getMaxPosSelf( + TensorView* tv, + bool best_effort, + bool allow_reduction, + bool allow_vectorize, + bool allow_unmappable) const; + + // Returns the maximum position producer can be inlined based on consumer + // given the best_effort flag + size_t getMaxProducerPosFromConsumer( + TensorView* producer, + TensorView* consumer, + bool best_effort) const; + + // Checks producers, consumers, and siblings to see what the maximum position + // in tv is that can be shared across both directions. + size_t getMaxPosAll( + TensorView* tv, + bool best_effort = false, + bool check_siblings = true); + + MaxPosCalculator(const std::unordered_set& uninlinable_ids = {}); +}; + +// Inline to the right most allowed position for all tensors in the current +// fusion. +TORCH_CUDA_CU_API void inlineMost( + const std::unordered_set& uninlinable_ids = {}); +// Inline to the right most allowed position for the selected tensors in the +// current fusion. +TORCH_CUDA_CU_API void inlineMost( + const std::vector& tvs, + const std::unordered_set& uninlinable_ids = {}); +// Inline to the right most allowed position for the selected tensors in the +// current fusion. 
+TORCH_CUDA_CU_API void inlineMost( + const std::unordered_set& tvs, + const std::unordered_set& uninlinable_ids = {}); + +// Inline to the position corresponding to the reference position in the +// reference tensor for all tensors in the current fusion. +TORCH_CUDA_CU_API void inlineAllAt( + TensorView* reference_tv, + int64_t reference_pos, + bool best_effort = false, + const std::unordered_set& uninlinable_ids = {}); + +// Inline to the position corresponding to the reference position in the +// reference tensor for selected tensors in the current fusion. +TORCH_CUDA_CU_API void inlineSelectedAt( + const std::unordered_set& selected, + TensorView* reference_tv, + int64_t reference_pos, + bool best_effort = false, + const std::unordered_set& uninlinable_ids = {}); + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h index 126abba2ae1032..dbefc4858d110c 100644 --- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h @@ -154,8 +154,6 @@ class TORCH_CUDA_CU_API ComplexDouble : public Val { //! the compute at position to maximum possible through traversal. 
enum class ComputeAtMode { Standard, BestEffort, MostInlined }; -class InlinePropagator; -class MaxProducerPosUpdater; class TransformPropagator; struct MostInlinedTransformPropagator; class TransformIter; @@ -163,6 +161,8 @@ class TransformReplay; class OptOutMutator; class TensorDomain; +class MaxPosCalculator; + namespace ir_utils { class TVDomainGuard; } @@ -492,21 +492,30 @@ class TORCH_CUDA_CU_API TensorView : public Val { friend TORCH_CUDA_CU_API MostInlinedTransformPropagator; friend TORCH_CUDA_CU_API TransformReplay; friend TORCH_CUDA_CU_API OptOutMutator; - friend TORCH_CUDA_CU_API InlinePropagator; - friend TORCH_CUDA_CU_API MaxProducerPosUpdater; + friend class InlineBatchingGuard; friend class ir_utils::TVDomainGuard; - friend TORCH_CUDA_CU_API void groupReductions( - const std::vector&); + + // Inline the computation of this tensor into its consumer at the given + // position. If this tensor is already inlined in a higher position, then this + // call is a no-op. If the rightmost dimensions before `pos` are broadcast + // dimensions, the tensor is not inlined into them. If best_effort is set, + // the tensor is inlined into the highest allowed position that is <= + // `pos`. + void inlineAt( + int64_t pos, + bool best_effort = false, + MaxPosCalculator* calc = nullptr); + + // Update the max producer position of the current tensor. This is required + // when we modify producer-consumer relationship of a scheduled tensor, for + // example, grouping multiple reductions.
+ void updateMaxProducerPosition(); protected: void setDomain(TensorDomain* td) { domain_ = td; } - void setComputeAt(unsigned int this_pos, bool decrease = false); - - void setMaxProducer(unsigned int this_pos, bool decrease = false); - private: int normalizeAxisPos(int pos) const { if (pos < 0) { diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp index bd887a9a1754a6..b40e6fbf7cf7a5 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp @@ -1,7 +1,7 @@ #include #include -#include +#include #include #include #include @@ -805,9 +805,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { // get a higher position in later inline propagation. We need this separate // step because we were not using ParallelType::Unroll, so we have to do // unrolling manually. - InlinePropagator inline_unswitch( - reference_tv, unswitch_pos, ComputeAtMode::BestEffort); - spanning_tree.traverse(&inline_unswitch); + inlineAllAt(reference_tv, unswitch_pos, true); auto all_tvs = ir_utils::allTvs(fusion); @@ -822,9 +820,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { auto output = entry.second; inner_most_tensors.erase(output); } - InlinePropagator inline_inner_most( - reference_tv, -1, ComputeAtMode::BestEffort, inner_most_tensors); - spanning_tree.traverse(&inline_inner_most); + inlineMost(inner_most_tensors); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp index 6bd4d4efba3767..ae9ecd88bbdc3c 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp @@ -1,7 +1,7 @@ #include #include -#include +#include #include #include #include @@ -336,14 +336,7 @@ void multiReductionInliner( scheduler_utils::getTrivialReductionMap(fusion); 
// Inline the schedule - InlinePropagator inline_propagator( - reference_tv, - -1, - ComputeAtMode::MostInlined, - {}, - mapped_to_trivial_reduction); - - MaxRootDomainInfoSpanningTree(reference_tv).traverse(&inline_propagator); + inlineMost(mapped_to_trivial_reduction); } namespace { diff --git a/torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp b/torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp index 28df2bece60799..b7e85cbc1c5e77 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp @@ -1,7 +1,7 @@ #include #include -#include +#include #include #include #include @@ -1131,9 +1131,7 @@ void scheduleTranspose(Fusion* fusion, TransposeParams params) { } // Inline - InlinePropagator inline_propagator( - reference1, -1, ComputeAtMode::MostInlined); - entire_dag.traverse(&inline_propagator); + inlineMost(); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp index ba95d8fabdce99..633c98102e2e0d 100644 --- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp +++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -290,40 +291,115 @@ IterDomain* TensorView::axis(int pos) const { return domain()->axis(pos); } -void TensorView::setComputeAt(unsigned int pos, bool decrease) { +void TensorView::inlineAt( + int64_t pos, + bool best_effort, + MaxPosCalculator* calc) { TORCH_INTERNAL_ASSERT( !container()->isA(), "Function invalid for kernel container."); - if (pos <= compute_at_pos_ && !decrease) { - return; + + std::unique_ptr calc_owner; + if (calc == nullptr) { + calc_owner = std::make_unique(); + calc = calc_owner.get(); + } + + if (pos < 0) { + pos += int64_t(nDims()) + 1; } TORCH_INTERNAL_ASSERT( - (unsigned)pos <= nDims(), - "Invalid this computeAt position for T", + pos >= 0 && pos <= nDims(), + "Invalid inline position for T", name(), ": ", pos); - 
compute_at_pos_ = pos; -} + auto max_inline_pos = calc->getMaxPosAll(this, best_effort); -void TensorView::setMaxProducer(unsigned int pos, bool decrease) { - TORCH_INTERNAL_ASSERT( - !container()->isA(), - "Function invalid for kernel container."); - if (pos <= max_producer_pos_ && !decrease) { - return; + if (best_effort) { + pos = std::min(max_inline_pos, pos); + } + + // hoist inner most broadcast + while (pos > 0 && axis(pos - 1)->isBroadcast()) { + pos--; } TORCH_INTERNAL_ASSERT( - (unsigned)pos <= nDims(), - "Invalid max producer position for T", + pos <= max_inline_pos, + "Invalid inline position for T", name(), ": ", - pos); + pos, + ". Maximum allowed value:", + max_inline_pos); + + if (isFusionInput()) { + return; + } + + if (pos > compute_at_pos_) { + compute_at_pos_ = pos; + for (auto consumer : ir_utils::consumerTvsOf(this)) { + consumer->updateMaxProducerPosition(); + } + } +} + +namespace { + +// Try to find the aligned position on consumer's domain corresponding to the +// compute at position of producer domain. No checking on actual +// producer-consumer relationship. +unsigned int getConsumerPosAlignedToProducerCA( + TensorView* consumer, + TensorView* producer) { + // Locate consumer's position that aligns with + // the producer's new compute at axis. We need broadcast axes forwarded so we + // need to replay PasC as CasP will not forward broadcast dims. For example + // if we have: + // T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) = broadcast( T1[ iS1{3} ] ca_pos( 1 ) + // produce_pos( 1) ) CasP will have the mapping iS1{3} -> iS2{3} and PasC will + // have the mapping iS22{( 3 * 1 )} <- iS1{3} We need the latter. Refer to + // NVFuserTest.FusionComplexBCast1_CUDA + + auto disjoint_sets = + BestEffortReplay::replayPasC( + producer, consumer, -1, PairwiseRootDomainMap(producer, consumer)) + .getDisjointSets(); + + // Find the innermost position of consumer that has + // been mapped within the producer ca axis.
+ unsigned int consumer_pos = consumer->nDims(); + while (consumer_pos > 0) { + auto consumer_id = consumer->axis((int)consumer_pos - 1); + auto p_dom = producer->domain()->domain(); + if (std::any_of( + p_dom.begin(), + p_dom.begin() + producer->getComputeAtPosition(), + [&consumer_id, &disjoint_sets](IterDomain* p_id) { + return disjoint_sets.permissiveAreMapped(consumer_id, p_id); + })) { + break; + } + consumer_pos--; + } + + return consumer_pos; +} + +} // namespace - max_producer_pos_ = pos; +void TensorView::updateMaxProducerPosition() { + TORCH_INTERNAL_ASSERT( + !container()->isA(), + "Function invalid for kernel container."); + for (auto producer : ir_utils::producerTvsOf(this)) { + max_producer_pos_ = std::max( + max_producer_pos_, getConsumerPosAlignedToProducerCA(this, producer)); + } } TensorView* TensorView::computeAt( diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index 1ff36c1a5d29f7..ee5e55bd592e17 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -11,7 +11,7 @@ #include #include #include -#include +#include #include #include #include @@ -11522,7 +11522,7 @@ TEST_F(NVFuserTest, FusionNonUniqueBroadcastSize_CUDA) { fusion.addInput(tv1); fusion.addInput(tv2); - auto tv3 = broadcast(tv0, {false, true}); + auto tv3 = broadcast(tv0, {true, false}); auto tv4 = add(tv3, tv1); auto tv5 = add(tv3, tv2); @@ -25084,7 +25084,7 @@ TEST_F( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST_F(NVFuserTest, FusionInlinePropagatorMismatchedDims1_CUDA) { +TEST_F(NVFuserTest, FusionInliningMismatchedDims1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -25097,8 +25097,7 @@ TEST_F(NVFuserTest, FusionInlinePropagatorMismatchedDims1_CUDA) { auto tv5 = tan(tv4); fusion.addOutput(tv5); - InlinePropagator inline_propagator(tv5, -1, ComputeAtMode::MostInlined); - 
MaxRootDomainInfoSpanningTree(tv5).traverse(&inline_propagator); + inlineMost(); TORCH_CHECK(tv5->getComputeAtPosition() == 3); TORCH_CHECK(tv4->getComputeAtPosition() == 3); @@ -25118,7 +25117,7 @@ TEST_F(NVFuserTest, FusionInlinePropagatorMismatchedDims1_CUDA) { testValidate(&fusion, cg_outputs, {input}, {output}, __LINE__, __FILE__); } -TEST_F(NVFuserTest, FusionInlinePropagatorMismatchedDims2_CUDA) { +TEST_F(NVFuserTest, FusionInliningMismatchedDims2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -25131,8 +25130,7 @@ TEST_F(NVFuserTest, FusionInlinePropagatorMismatchedDims2_CUDA) { auto tv5 = tan(tv4); fusion.addOutput(tv5); - InlinePropagator inline_propagator(tv5, -1, ComputeAtMode::BestEffort); - MaxRootDomainInfoSpanningTree(tv5).traverse(&inline_propagator); + inlineAllAt(tv5, -1, true); TORCH_CHECK(tv5->getComputeAtPosition() == 3); TORCH_CHECK(tv4->getComputeAtPosition() == 3); @@ -25152,7 +25150,7 @@ TEST_F(NVFuserTest, FusionInlinePropagatorMismatchedDims2_CUDA) { testValidate(&fusion, cg_outputs, {input}, {output}, __LINE__, __FILE__); } -TEST_F(NVFuserTest, FusionInlinePropagatorMismatchedDims3_CUDA) { +TEST_F(NVFuserTest, FusionInliningMismatchedDims3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -25176,8 +25174,7 @@ TEST_F(NVFuserTest, FusionInlinePropagatorMismatchedDims3_CUDA) { tv->merge(2); } - InlinePropagator inline_propagator(tv8, -1, ComputeAtMode::MostInlined); - MaxRootDomainInfoSpanningTree(tv8).traverse(&inline_propagator); + inlineMost(); TORCH_CHECK(tv8->getComputeAtPosition() == 3); TORCH_CHECK(tv7->getComputeAtPosition() == 3); @@ -25200,7 +25197,7 @@ TEST_F(NVFuserTest, FusionInlinePropagatorMismatchedDims3_CUDA) { testValidate(&fusion, cg_outputs, {input}, {output}, __LINE__, __FILE__); } -TEST_F(NVFuserTest, FusionInlinePropagatorMismatchedDims4_CUDA) { +TEST_F(NVFuserTest, FusionInliningMismatchedDims4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -25214,8 +25211,7 @@ TEST_F(NVFuserTest, 
FusionInlinePropagatorMismatchedDims4_CUDA) { fusion.addOutput(tv5); tv3->merge(1); - InlinePropagator inline_propagator(tv0, -1, ComputeAtMode::MostInlined); - MaxRootDomainInfoSpanningTree(tv0).traverse(&inline_propagator); + inlineMost(); TORCH_CHECK(tv5->getComputeAtPosition() == 3); TORCH_CHECK(tv4->getComputeAtPosition() == 3); @@ -25235,7 +25231,7 @@ TEST_F(NVFuserTest, FusionInlinePropagatorMismatchedDims4_CUDA) { testValidate(&fusion, cg_outputs, {input}, {output}, __LINE__, __FILE__); } -TEST_F(NVFuserTest, FusionInlinePropagatorBroadcast_CUDA) { +TEST_F(NVFuserTest, FusionInliningBroadcast_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -25254,8 +25250,7 @@ TEST_F(NVFuserTest, FusionInlinePropagatorBroadcast_CUDA) { tv->merge(2); } - InlinePropagator inline_propagator(tv0, -1, ComputeAtMode::MostInlined); - MaxRootDomainInfoSpanningTree(tv0).traverse(&inline_propagator); + inlineMost(); TORCH_CHECK(tv4->getComputeAtPosition() == 3); TORCH_CHECK(tv3->getComputeAtPosition() == 3); @@ -25274,7 +25269,7 @@ TEST_F(NVFuserTest, FusionInlinePropagatorBroadcast_CUDA) { testValidate(&fusion, cg_outputs, {input}, {output}, __LINE__, __FILE__); } -TEST_F(NVFuserTest, FusionInlinePropagatorBroadcastTrivialReduction_CUDA) { +TEST_F(NVFuserTest, FusionInliningBroadcastTrivialReduction_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -25296,8 +25291,7 @@ TEST_F(NVFuserTest, FusionInlinePropagatorBroadcastTrivialReduction_CUDA) { tv->merge(2); } - InlinePropagator inline_propagator(tv6, -1, ComputeAtMode::MostInlined); - MaxRootDomainInfoSpanningTree(tv6).traverse(&inline_propagator); + inlineMost(); TORCH_CHECK(tv6->getComputeAtPosition() == 3); TORCH_CHECK(tv5->getComputeAtPosition() == 3); @@ -25387,8 +25381,7 @@ TEST_F(NVFuserTest, FusionIdGraphTrivialReduction_CUDA) { tv->merge(2); } - InlinePropagator inline_propagator(tv3, -1, ComputeAtMode::MostInlined); - MaxRootDomainInfoSpanningTree(tv3).traverse(&inline_propagator); + inlineMost(); ComputeAtMap 
ca_map(&fusion); @@ -25654,8 +25647,7 @@ TEST_F(NVFuserTest, FusionPredicateUnshare_CUDA) { tv->axis(-1)->parallelize(ParallelType::TIDx); } - InlinePropagator propagator(tv2, -1, ComputeAtMode::MostInlined); - MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + inlineMost(); auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({5, 5}, options); @@ -25746,8 +25738,7 @@ TEST_F(NVFuserTest, FusionMergeBroadcastingTrivialReduction1_CUDA) { TransformPropagatorWithCheck tp(tv0); tree.traverse(&tp); - InlinePropagator ip(tv0, -1, ComputeAtMode::MostInlined); - tree.traverse(&ip); + inlineMost(); auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({1, 1}, options); @@ -25782,8 +25773,7 @@ TEST_F(NVFuserTest, FusionMergeBroadcastingTrivialReduction2_CUDA) { TransformPropagatorWithCheck tp(tv0); tree.traverse(&tp); - InlinePropagator ip(tv0, -1, ComputeAtMode::MostInlined); - tree.traverse(&ip); + inlineMost(); auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({10, 1, 1}, options); @@ -25986,6 +25976,30 @@ TEST_F(NVFuserTest, FusionMappingRelation_CUDA) { fusion, {out}, {t0, t1}, {t1 + t0.squeeze(0)}, __LINE__, __FILE__); } +TEST_F(NVFuserTest, FusionInlineAt_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + fusion->addInput(tv0); + auto tv1 = sin(tv0); + auto tv2 = cos(tv1); + fusion->addOutput(tv2); + + tv1->inlineAt(-1); + + auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({100, 2}, options); + + FusionExecutor fe; + fe.compileFusion(fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + auto out = cg_outputs[0]; + + testValidate(fusion, {out}, {t0}, {t0.sin().cos()}, __LINE__, __FILE__); +} + } // namespace jit } // namespace torch #endif // #if defined(USE_CUDA) 
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp index 3b9e7cbd962c65..e827de56e56bdd 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp @@ -10,7 +10,7 @@ #include #include #include -#include +#include #include #include #include @@ -2391,10 +2391,7 @@ TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduceWelfordShmoo_CUDA) { transform_ref_rf->axis(unswitch_id)->parallelize(ParallelType::Unswitch); - InlinePropagator inline_propagator( - transform_ref_rf, -1, ComputeAtMode::MostInlined); - MaxRootDomainInfoSpanningTree(transform_ref_rf) - .traverse(&inline_propagator); + inlineMost(); // Make sure the reduction expr is converted to GroupedGridReduciton // and the non-reduction domains of the output TV are either diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp index 252dfa00ef265a..ad43b0ed4e07d2 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp @@ -3,7 +3,7 @@ #include #include -#include +#include #include #include #include @@ -750,9 +750,7 @@ TEST_F(NVFuserTest, FusionManualScheduleTransposeComplexDAG1_CUDA) { } // inline - MaxRootDomainInfoSpanningTree entire_dag(tv9); - InlinePropagator inline_propagator(tv9, -1, ComputeAtMode::MostInlined); - entire_dag.traverse(&inline_propagator); + inlineMost(); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input0 = at::randn({512, 1024, 256}, options); diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_view.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_view.cpp index 3cb5f9d985a7fd..1ed73d3256bcf9 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_view.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_view.cpp @@ -10,7 +10,7 @@ 
#include #include #include -#include +#include #include #include #include @@ -1390,8 +1390,7 @@ TEST_F(NVFuserTest, FusionPwiseViewSchedule_CUDA) { scheduler_utils::parallelizeAllLike(tv5); // Inline the schedule - InlinePropagator inline_propagator(tv5, -1, ComputeAtMode::MostInlined); - spanning_tree.traverse(&inline_propagator); + inlineMost(); } auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); @@ -1457,8 +1456,7 @@ TEST_F(NVFuserTest, FusionSumViewSchedule_CUDA) { scheduler_utils::parallelizeAllLike(tv5_rf); // Inline the schedule - InlinePropagator inline_propagator(tv5_rf, -1, ComputeAtMode::MostInlined); - spanning_tree.traverse(&inline_propagator); + inlineMost(); } auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); @@ -1763,8 +1761,7 @@ TEST_F(NVFuserTest, FusionViewMapping_CUDA) { scheduler_utils::parallelizeAllLike(tv6); // Inline the schedule - InlinePropagator inline_propagator(tv6, -1, ComputeAtMode::MostInlined); - spanning_tree.traverse(&inline_propagator); + inlineMost(); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); @@ -1805,8 +1802,7 @@ TEST_F(NVFuserTest, FusionLowerDivisibleSplits_CUDA) { scheduler_utils::parallelizeAllLike(tv2); // Inline the schedule - InlinePropagator inline_propagator(tv2, -1, ComputeAtMode::MostInlined); - spanning_tree.traverse(&inline_propagator); + inlineMost(); auto divisible_splits = getAllDivisibleSplits(&fusion); diff --git a/torch/csrc/jit/codegen/cuda/utils.cpp b/torch/csrc/jit/codegen/cuda/utils.cpp index a79c4d2db83ad9..101aac58479ad0 100644 --- a/torch/csrc/jit/codegen/cuda/utils.cpp +++ b/torch/csrc/jit/codegen/cuda/utils.cpp @@ -38,7 +38,6 @@ auto parseDebugDumpOptions() { {DebugDumpOption::Halo, false}, {DebugDumpOption::PerfDebugVerbose, false}, {DebugDumpOption::TransformPropagator, false}, - {DebugDumpOption::InlinePropagator, false}, {DebugDumpOption::Cubin, false}, {DebugDumpOption::Ptx, false}}; @@ -91,8 +90,6 @@ auto 
parseDebugDumpOptions() { options_map[DebugDumpOption::PerfDebugVerbose] = true; } else if (token == "transform_propagator") { options_map[DebugDumpOption::TransformPropagator] = true; - } else if (token == "inline_propagator") { - options_map[DebugDumpOption::InlinePropagator] = true; } else if (token == "cubin") { options_map[DebugDumpOption::Cubin] = true; } else if (token == "ptx") { @@ -108,7 +105,7 @@ auto parseDebugDumpOptions() { "\tkernel_args, dump_eff_bandwidth, draw_segmented_fusion,\n", "\tscheduler_params, parallel_dimensions, buffer_reuse_verbose,\n", "\tptxas_verbose, halo, segmenter_logging, perf_debug_verbose\n", - "\ttransform_propagator, inline_propagator, cubin, ptx\n"); + "\ttransform_propagator, cubin, ptx\n"); } options_view = (end_pos != c10::string_view::npos) ? options_view.substr(end_pos + 1) diff --git a/torch/csrc/jit/codegen/cuda/utils.h b/torch/csrc/jit/codegen/cuda/utils.h index c8205b88e3f972..5e69ac2bb22b9c 100644 --- a/torch/csrc/jit/codegen/cuda/utils.h +++ b/torch/csrc/jit/codegen/cuda/utils.h @@ -54,8 +54,6 @@ enum class DebugDumpOption { //! associated with what's running TransformPropagator, //! When running TransformPropagator, print propagation //! path and replay result - InlinePropagator, //! When running InlinePropagator, print propagation - //! path and inlining result Cubin, //! Dump compiled CUBIN Ptx //! Dump compiled PTX };