diff --git a/CMakeLists.txt b/CMakeLists.txt index 853b55ef4f7..a76988abd1a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,6 +13,14 @@ set(NVFUSER_THIRD_PARTY_DIR "${NVFUSER_ROOT}/third_party") option(NVFUSER_STANDALONE_BUILD_WITH_UCC "" OFF) option(NVFUSER_BUILD_WITH_ASAN "Build nvFuser with asan" OFF) +include(CMakeDependentOption) +cmake_dependent_option(NVFUSER_DISTRIBUTED "" ON + "USE_DISTRIBUTED" OFF) +if (NVFUSER_DISTRIBUTED) + add_compile_definitions(NVFUSER_DISTRIBUTED) +endif() +message(STATUS "Setting NVFUSER_DISTRIBUTED=${NVFUSER_DISTRIBUTED}") + if(NOT NVFUSER_CPP_STANDARD) set(NVFUSER_CPP_STANDARD 20) endif() @@ -198,6 +206,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/optimization/pre_segmenter.cpp ${NVFUSER_SRCS_DIR}/optimization/remove_empty.cpp ${NVFUSER_SRCS_DIR}/val_graph.cpp + ${NVFUSER_SRCS_DIR}/val_graph_visitor.cpp ) # We don't link CUPTI for MSVC @@ -639,6 +648,7 @@ message(STATUS "******** Nvfuser configuration summary ********") message(STATUS " UCC_FOUND: ${UCC_FOUND}") message(STATUS " NVFUSER_STANDALONE_BUILD_WITH_UCC : ${NVFUSER_STANDALONE_BUILD_WITH_UCC}") message(STATUS " NVFUSER_BUILD_WITH_ASAN : ${NVFUSER_BUILD_WITH_ASAN}") +message(STATUS " NVFUSER_DISTRIBUTED : ${NVFUSER_DISTRIBUTED}") message(STATUS " NVFUSER_CPP_STANDARD : ${NVFUSER_CPP_STANDARD}") if(NVFUSER_STANDALONE_BUILD_WITH_UCC) diff --git a/benchmark/matmul.cpp b/benchmark/matmul.cpp index 4a1cc265f85..068a17f393d 100644 --- a/benchmark/matmul.cpp +++ b/benchmark/matmul.cpp @@ -55,17 +55,21 @@ void setupMatmul( auto c = matmul(a, b, layout, turing_or_later); + // Cast the output so that we perform an HSH matmul, which is what at::matmul + // will perform + auto d = castOp(DataType::Half, c); + fusion->addInput(a); fusion->addInput(b); - fusion->addOutput(c); + fusion->addOutput(d); scheduleMatmul(fusion, params); } void checkMatch(at::Tensor expect, at::Tensor result, int64_t k) { // tolerance - double rtol = 1e-6 * k; - double atol = 1e-6 * k; + double rtol = 1e-4 * k; + double atol = 1e-4 * k; auto ndim = result.ndimension(); auto is_close = at::isclose(expect, result, rtol, atol); @@ -145,7 +149,7 @@ static void SingleMatmulBase( int64_t k = benchmark_state.range(2); // Tensor inputs - auto inputs = matmulAtInput(m, n, k, layout); + auto inputs = matmulAtInput2D(m, n, k, layout); auto expected_output = atMatmul( inputs.first.to(at::kDouble), inputs.second.to(at::kDouble), layout); @@ -203,10 +207,16 @@ static void Baseline_Matmul( benchmark_state.range(1), benchmark_state.range(2)}; + bool allow_half_reduction = (bool)benchmark_state.range(3); + at::manual_seed(0); - auto inputs = - matmulAtInput(input_mnk.at(0), input_mnk.at(1), input_mnk.at(2), layout); + auto inputs = matmulAtInput2D( + input_mnk.at(0), input_mnk.at(1), input_mnk.at(2), layout); + + // Disable reduced-precision reduction for fair comparison since we do not use + // it in nvFuser + at::globalContext().setAllowFP16ReductionCuBLAS(allow_half_reduction); // warm up run auto outputs = atMatmul(inputs.first, inputs.second, layout); @@ -328,9 +338,9 @@ static void SingleMatmulPartitionedK( scheduleMatmul(fusion, params); - at::Tensor aten_a = matmulAtInput( + at::Tensor aten_a = matmulAtInput2D( layout, TensorMatmulPos::A, at::kHalf, M, N, Ki, splitk_factor); - at::Tensor aten_b = matmulAtInput( + at::Tensor aten_b = matmulAtInput2D( layout, TensorMatmulPos::B, at::kHalf, M, N, Ki, splitk_factor); std::vector aten_inputs = {aten_a, aten_b}; at::Tensor expected_output = splitkLikeAtMatmul( @@ -496,7 +506,24 @@ 
static void NvFuserScheduler_MatmulSplitKReduction( {784, 72, 8}, \ {784, 8, 72}, \ /* {1, 1, 2048}, */ \ - {1024, 1024, 1024} \ + {1024, 1024, 1024}, \ + /* NanoGPT bwd sizes */ \ + {1024, 2048, 4096}, \ + {1024, 2048, 50304} \ + } + +#define SplitKSpecificShapes \ + { \ + /* NanoGPT bwd sizes */ \ + {1024, 2048, 4096}, \ + {1024, 2048, 50304}, \ + /* Symmetric M,N to make comparison in TN/NT fair with eager due to transpose/swap */ \ + {1024, 1024, 4096}, \ + {1024, 1024, 50304}, \ + /* Sizes mentioned by Michel */ \ + {136, 184, 175704}, \ + /* Other */ \ + {128, 128, 262144} \ } // clang-format on @@ -527,7 +554,7 @@ static std::vector splitKNs(long int tileN = 128) { #define NumWarps \ { 4, 8 } #define NumStages \ - { 3, 4 } + { 3, 4, 5 } //! Simple cartesian product of three integers. Used to emulate ArgsProduct template @@ -560,6 +587,20 @@ static std::vector> sizeProduct( return sizes; } +// Use this to apply shape arguments to a benchmark without additional +// NVFuser-specific args. Used for eager benchmarks to avoid redundant +// benchmarks for combinations of num_warps and num_stages +static void MatmulShapeEager( + benchmark::internal::Benchmark* b, + std::vector> sizes) { + b->ArgNames({"M", "N", "K", "half_reduction"}); + for (auto [m, n, k] : sizes) { + for (bool allow_half_reduction : {false, true}) { + b->Args({m, n, k, allow_half_reduction}); + } + } +} + // Use this to apply shape arguments to a benchmark without additional // NVFuser-specific args. Used for eager benchmarks to avoid redundant // benchmarks for combinations of num_warps and num_stages @@ -603,13 +644,29 @@ static void MatmulShapeWarpStageAutoSplitK(benchmark::internal::Benchmark* b) { } } +// Use this for manual splitk. +static void MatmulShapeWarpStageSpecificSplitK( + benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K", "warps", "stages", "splitk_factor"}); + for (long int num_warps : NumWarps) { + for (long int num_stages : NumStages) { + for (auto [m, n, k] : + std::vector>(SplitKSpecificShapes)) { + for (auto splitk_factor : {2, 3, 4, 5, 6}) { + b->Args({m, n, k, num_warps, num_stages, splitk_factor}); + } + } + } + } +} + #define EagerModeBenchmark(layout) \ BENCHMARK_CAPTURE( \ Baseline_Matmul, eagermode_legacyshapes_##layout, MmaLayout::layout) \ ->Unit(benchmark::kMicrosecond) \ ->UseManualTime() \ ->Apply([](benchmark::internal::Benchmark* b) { \ - return MatmulShape( \ + return MatmulShapeEager( \ b, sizeProduct(LegacyMs, LegacyNs, LegacyKs)); \ }); \ BENCHMARK_CAPTURE( \ @@ -617,14 +674,14 @@ static void MatmulShapeWarpStageAutoSplitK(benchmark::internal::Benchmark* b) { ->Unit(benchmark::kMicrosecond) \ ->UseManualTime() \ ->Apply([](benchmark::internal::Benchmark* b) { \ - return MatmulShape(b, TIMMShapes); \ + return MatmulShapeEager(b, TIMMShapes); \ }); \ BENCHMARK_CAPTURE( \ Baseline_Matmul, eagermode_splitkshapes_##layout, MmaLayout::layout) \ ->Unit(benchmark::kMicrosecond) \ ->UseManualTime() \ ->Apply([](benchmark::internal::Benchmark* b) { \ - return MatmulShape( \ + return MatmulShapeEager( \ b, sizeProduct(SplitKMs, splitKNs(), SplitKKs)); \ }); @@ -683,9 +740,27 @@ static void MatmulShapeWarpStageAutoSplitK(benchmark::internal::Benchmark* b) { ->UseManualTime() \ ->Apply(MatmulShapeWarpStageAutoSplitK); +static void NvFuserScheduler_Matmul_Manual( + benchmark::State& benchmark_state, + MmaLayout layout) { + int splitk_factor = benchmark_state.range(5); + NvFuserScheduler_Matmul( + benchmark_state, layout, splitk_factor, /*partitionedk=*/false); +} + +#define 
SpecificSplitKBenchmark(layout) \ + BENCHMARK_CAPTURE( \ + NvFuserScheduler_Matmul_Manual, \ + nvfuser_splitk_##layout, \ + MmaLayout::layout) \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime() \ + ->Apply(MatmulShapeWarpStageSpecificSplitK); + ForAllLayouts(EagerModeBenchmark); ForAllLayouts(NvfuserMatmulBenchmark); ForAllLayouts(AutoSplitKBenchmark); +ForAllLayouts(SpecificSplitKBenchmark); ForAllLayouts(AutoPartitionedKBenchmark); // Note: SplitK Reduction benchmarks are parametrized only by M, N. The splitk diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 9b68cf4fdb5..6e11cf93d89 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -1693,11 +1693,16 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { block_flags, ArgumentBuilder().arg("gridDim")); + int64_t vectorize_size = ir_utils::getVectorizeSize(out->view()); + + ArgumentBuilder template_args; + template_args.arg("/*vec_size=*/").append(std::to_string(vectorize_size)); + ArgumentBuilder func_args(block_nest_level_ + 1, kTab); - func_args.arg(gen(out)); - func_args.arg(gen(grop->in())); + func_args.arg("&").append(gen(out)); + func_args.arg("&").append(gen(grop->in())); func_args.arg(gen(grop->init())); - func_args.arg(gen(grop->serialReductionTensor())); + func_args.arg("&").append(gen(grop->serialReductionTensor())); func_args.arg(genReductionOp(op_type, out->dtype())); // Whether this is the first or last step @@ -1720,7 +1725,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { func_args.arg(read_pred); } - indent() << "reduction::serialReductionStep(\n"; + indent() << "reduction::serialReductionStep<" << template_args << ">(\n"; indent() << kTab << func_args << ");\n"; } diff --git a/csrc/device_lower/pass/expr_sort.cpp b/csrc/device_lower/pass/expr_sort.cpp index 738a544aaa4..a18002cbf16 100644 --- a/csrc/device_lower/pass/expr_sort.cpp +++ b/csrc/device_lower/pass/expr_sort.cpp @@ -1516,7 +1516,7 @@ void ExprSegmentationSorter::sort() { std::back_inserter(non_pointer_arithmetic_outs), [this](Val* out) { return fusion_->getOutputAlias(out).type != - AliasType::PointerArithmetic; + AllocationType::PointerArithmetic; }); // Not putting the exprs between fusion inputs and allKnownVals() here diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index d95aaeae4a8..31b3dfbe31d 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -477,7 +477,12 @@ class VectorizeValidator : public OptInDispatch { tv, (int)vector_word_size); } - auto producer_tv = tv->definition()->inputs().at(0)->as(); + auto tv_def = tv->definition(); + NVF_ERROR( + tv_def != nullptr, + "Tv has no definition, cannot validate vectorization:", + tv); + auto producer_tv = tv_def->inputs().at(0)->as(); auto producer_word_size_it = GpuLower::current()->vectorizedAccesses().find(producer_tv); if (producer_word_size_it != @@ -579,9 +584,11 @@ void validateAndCollectVectorizeInfo(Fusion* fusion) { } } if (has_vectorize_dim) { + Expr* def = tv->definition(); NVF_ERROR( - tv->definition() == nullptr || tv->definition()->isA() || - tv->definition()->isA(), + def == nullptr || def->isA() || def->isA() || + (def->isA() && + def->as()->serialGridReductionRequested()), "Vectorized accesses cannot be inline with computation, they are only supported with a Set operation.", "TensorView: ", tv); diff --git a/csrc/disjoint_set.h b/csrc/disjoint_set.h index 25f4c183af0..d1af78fef7f 100644 --- a/csrc/disjoint_set.h +++ b/csrc/disjoint_set.h @@ -242,6 +242,14 @@ class VectorOfUniqueEntries { 
return vector_.end(); } + T& at(size_t pos) { + return vector_.at(pos); + } + + const T& at(size_t pos) const { + return vector_.at(pos); + } + std::string toString() const { std::stringstream ss; ss << "{ "; diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index b23e5a781ef..72a3ec66aa4 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -647,8 +647,35 @@ void DynamicTransformConcretizer::concretizeReshape() { // Extent expressions often change when concretizing a reshape. Here we // replace these in all downstream expressions so that the Fusion looks just // like it would have if we had used a static reshape instead. + // + // Note that Reduction IterDomains might be present in the concretized + // reshape. For example, suppose we are given the following dynamic Fusion + // + // Inputs: + // T0 + // Outputs: + // T3 + // T1[ iS2{i0} rS3{i1} ] = sum(T0[ iS0{i0} iS1{i1} ]) + // T2[ ?S4{i2} ] = view(T1[ iS2{i0} rS3{i1} ]) + // T3[ ?S4{i2} ] = -T2[ ?S4{i2} ] + // + // Then we will concretize this as + // + // Inputs: + // T0 + // Outputs: + // T3 + // T1[ iS2{i0} rS3{i1} ] = sum(T0[ iS0{i0} iS1{i1} ]) + // T3[ iS4{i0} ] = -T1[ iS2{i0} rS3{i1} ] + // + // Notice here that the ViewOp is gone since we recognized that there is no + // transformation to perform. Instead, T1 is used directly in place of T2. + // We also replace the extent i2 from the dynamic reshape output T2 with i0, + // which is what the code below implements. Since T1 includes a Reduction + // IterDomain, we must ignore it in order to match ?S4{i2} with iS2{i0}. auto old_rfactor = incomplete_out_tv->getMaybeRFactorDomain(); - auto new_rfactor = concrete_reshape_out_tv->getMaybeRFactorDomain(); + auto new_rfactor = TensorDomain::noReductions( + concrete_reshape_out_tv->getMaybeRFactorDomain()); NVF_ERROR( old_rfactor.size() == new_rfactor.size(), "Concretized reshape rfactor size does not match symbolic rfactor"); diff --git a/csrc/executor.cpp b/csrc/executor.cpp index b9abf2f48c1..fc696af98b0 100644 --- a/csrc/executor.cpp +++ b/csrc/executor.cpp @@ -330,7 +330,7 @@ void FusionExecutor::compileFusion( lowered_ = std::make_unique(fusion, compile_params); lowered_->run(); - const auto kernel = lowered_->kernel(); + kir::Kernel* kernel = lowered_->kernel(); for (const auto& hook : post_lowering_hooks_) { hook(kernel); } @@ -378,7 +378,7 @@ void FusionExecutor::compileFusion( structured_code = getStructuredCode(); } - const auto& kernel_summary = kernel->summary(); + const kir::KernelSummary& kernel_summary = kernel->summary(); // We currently shouldn't allocate any more shared mem // tensors statically but could keep this path if @@ -942,7 +942,7 @@ at::Tensor allocateOutput( return ee.evaluate(out_tv).as(); } - if (alias_info.type == AliasType::NoAlias) { + if (alias_info.type == AllocationType::NoAlias) { auto alloc_tensor = at::native::empty_strided_cuda( out_info.sizes, out_info.strides, @@ -959,7 +959,7 @@ at::Tensor allocateOutput( Val* aliased_io = alias_info.aliased_io; NVF_ERROR( aliased_io != nullptr, - "The other two AliasTypes currently must have an `aliased_io`."); + "The other two AllocationTypes currently must have an `aliased_io`."); NVF_ERROR( aliased_io->isFusionInput() || aliased_io->isFusionOutput(), aliased_io->toInlineString(), @@ -973,8 +973,8 @@ at::Tensor allocateOutput( PolymorphicValue_functions::toString(aliased_io_val)); auto aliased_io_tensor = aliased_io_val.as(); - if (alias_info.type == AliasType::InplaceUpdate) { - // Unlike for 
`AliasType::PointerArithmetic`, don't use + if (alias_info.type == AllocationType::InplaceUpdate) { + // Unlike for `AllocationType::PointerArithmetic`, don't use // ExpressionEvaluator to compute the output tensor. This is because // the output tensor may hold different data from the input, e.g., an // updated running mean. `ExpressionEvaluator::evaluate(out_tv)` @@ -982,7 +982,7 @@ at::Tensor allocateOutput( return aliased_io_tensor; } - NVF_ERROR(alias_info.type == AliasType::PointerArithmetic); + NVF_ERROR(alias_info.type == AllocationType::PointerArithmetic); at::Tensor out_tensor = ee.evaluate(out_tv).as(); NVF_ERROR( out_tensor.is_alias_of(aliased_io_tensor), @@ -1030,8 +1030,9 @@ std::vector allocateOutputs( const std::pair& lhs, const std::pair& rhs) { return ( - kernel->getOutputAlias(lhs.second).type == AliasType::NoAlias && - kernel->getOutputAlias(rhs.second).type != AliasType::NoAlias); + kernel->getOutputAlias(lhs.second).type == + AllocationType::NoAlias && + kernel->getOutputAlias(rhs.second).type != AllocationType::NoAlias); }); std::vector out_tensors(num_outs); @@ -1819,7 +1820,7 @@ std::vector FusionExecutor::runFusion( timer.init(); } - if (execute_kernel_) { + if (execute_kernel_ && !kernel()->topLevelExprs().empty()) { ensureAvailableDynamicSmemSize(executor_entry->launch_params.smem()); std::vector arg_buffer_ptrs; @@ -1890,7 +1891,8 @@ std::vector FusionExecutor::runFusion( outputBytesProcessed(outputs); - if (isDebugDumpEnabled(DebugDumpOption::EffectiveBandwidth)) { + if (isDebugDumpEnabled(DebugDumpOption::EffectiveBandwidth) || + isDebugDumpEnabled(DebugDumpOption::PerfDebugVerbose)) { double gb_per_s = ((double)bytesProcessed() / ((double)kernel_time_ms_ / 1000)) / (double)1.0e9; diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index 322fa910844..7d939da9088 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -63,7 +63,7 @@ void swap(Fusion& a, Fusion& b) noexcept { std::unique_ptr Fusion::segment( const KernelArgumentHolder& args) { FUSER_PERF_SCOPE("Segment Fusion"); - return SegmentCandidateFinder::segment(this, args); + return SegmentCandidateFinder::segment(this, &args); } IrCloner Fusion::copy(const Fusion* from, Fusion* to) { @@ -764,12 +764,15 @@ bool Fusion::isAliasCompatible(Val* left, Val* right) { return true; } -void Fusion::aliasOutputToInput(Val* output, Val* input, const AliasType type) { +void Fusion::aliasOutputToInput( + Val* output, + Val* input, + const AllocationType type) { NVF_CHECK( - type != AliasType::NoAlias, + type != AllocationType::NoAlias, "NoAlias is returned automatically for a missing key. Don't add it explicitly."); - if (type == AliasType::InplaceUpdate) { + if (type == AllocationType::InplaceUpdate) { // `input` can be a cast of a fusion input. 
if (!input->isFusionInput()) { auto input_expr = input->definition(); @@ -807,9 +810,11 @@ void Fusion::aliasOutputToInput(Val* output, Val* input, const AliasType type) { } } -const AliasInfo& Fusion::getOutputAlias(Val* output) const { +const AliasInfo& Fusion::getOutputAlias(const Val* output) const { static AliasInfo no_alias_info{ - .type = AliasType::NoAlias, .aliased_io = nullptr, .hide_output = false}; + .type = AllocationType::NoAlias, + .aliased_io = nullptr, + .hide_output = false}; if (auto search = io_alias_.find(output); search != io_alias_.end()) { return search->second; } diff --git a/csrc/fusion.h b/csrc/fusion.h index 08353db2107..e190debb4fc 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -86,7 +86,7 @@ class FusionGuard { // Set the enum base to `int` so it can be safely serialized as a part of // serde::InputOutputAlias. -enum class AliasType : int { +enum class AllocationType : int { NoAlias, // For example, the tensor storing BatchNorm's running mean. The output EMA is // updated in place. @@ -98,7 +98,7 @@ enum class AliasType : int { }; struct AliasInfo { - AliasType type; + AllocationType type; Val* aliased_io; // Whether integration should hide the output from users. This is currently // only used for InplaceUpdate. @@ -248,12 +248,12 @@ class Fusion : public IrContainer { // the input tensor to the section where output is produced. Currently, // aliases of type `PointerArithmetics` are marked after segmentation, but // those of type `InplaceUpdate` are marked in fusion definitions. - void aliasOutputToInput(Val* output, Val* input, AliasType type); + void aliasOutputToInput(Val* output, Val* input, AllocationType type); //! Returns the aliased input of a given output along with an `AliasInfo` //! describing how they alias. Returns when `output` is not //! aliased. - const AliasInfo& getOutputAlias(Val* output) const; + const AliasInfo& getOutputAlias(const Val* output) const; // mark input at index to be permuted by permutation void setPermutationOnInput(int index, std::vector permutation) { @@ -474,7 +474,7 @@ class Fusion : public IrContainer { std::vector outputs_; // io alias pointing from output to input - std::unordered_map io_alias_; + std::unordered_map io_alias_; // See Note [ Permutation support in nvfuser ] // map from indices of input tensor to permutation diff --git a/csrc/fusion_profiler.cpp b/csrc/fusion_profiler.cpp index 81c3cce6bb3..929302dca45 100644 --- a/csrc/fusion_profiler.cpp +++ b/csrc/fusion_profiler.cpp @@ -407,6 +407,10 @@ std::array column_strs{ std::ostream& operator<<(std::ostream& os, const FusionProfile& fp) { if (fp.fusion_id == 0) { + // `os` may have leftover characters in the line + // before the header is printed. So we start with a newline. + os << std::endl; + os << std::left << std::setw(5) << std::get<0>(column_strs) << " " << std::setw(5) << std::get<1>(column_strs) << " " << std::setw(11) << std::get<2>(column_strs) << " " << std::setw(9) @@ -628,15 +632,11 @@ void FusionProfiler::stop() { NVFUSER_CUPTI_SAFE_CALL(cuptiActivityDisable(CUPTI_ACTIVITY_KIND_DRIVER)); NVFUSER_CUPTI_SAFE_CALL( cuptiActivityDisable(CUPTI_ACTIVITY_KIND_EXTERNAL_CORRELATION)); + // This will be populated by the following `cuptiActivityFlushAll` call. 
fp->kernel_profiles_.reserve(fp->segments_.size()); - fprof.kernel_profiles.resize(fp->segments_.size()); - NVFUSER_CUPTI_SAFE_CALL(cuptiActivityFlushAll(0)); - NVF_CHECK( - fp->kernel_profiles_.size() >= fp->segments_.size(), - "All of the kernel profiles have not been recorded!"); - + fprof.kernel_profiles.resize(fp->segments_.size()); for (auto& kprof : fp->kernel_profiles_) { auto corr_id = kprof.correlation_id; if (fp->corrid_2_segid_.count(corr_id) == 0) { diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp index 00dd1ca7bc1..e11d2f2d8a2 100644 --- a/csrc/fusion_segmenter.cpp +++ b/csrc/fusion_segmenter.cpp @@ -1867,7 +1867,7 @@ std::unique_ptr SegmentedFusion::makeFusion(SegmentedGroup* sg) { std::unique_ptr SegmentCandidateFinder::segment( std::unique_ptr fusion, - const KernelArgumentHolder& inputs, + const KernelArgumentHolder* inputs, SchedulerRuntimeInfo& runtime_info) { if (!hasSegmentHints(fusion.get())) { scheduler_debug_utils::canScheduleMessage( @@ -1876,7 +1876,7 @@ std::unique_ptr SegmentCandidateFinder::segment( SchedulerEntry::proposeHeuristics(fusion.get(), runtime_info); if (maybe_complete_fusion_heuristic.has_value()) { return SegmentedFusion::fromCompleteFusion( - std::move(fusion), maybe_complete_fusion_heuristic.value(), inputs); + std::move(fusion), maybe_complete_fusion_heuristic.value(), *inputs); } } if (fusion) { @@ -2584,7 +2584,7 @@ class TranslateApplicableWelford { //! group containing all the welford ops needs to be //! provided. bool wouldTranslateToPersistent( - const std::vector& orignal_welfords, + const std::vector& original_welfords, SegmentedGroup* group = nullptr); //! Translate the given welford op into separate @@ -2617,12 +2617,12 @@ TranslateApplicableWelford::TranslateApplicableWelford( const KernelArgumentHolder& runtime_inputs) : runtime_inputs_(runtime_inputs) { auto exprs = fusion->exprs(); - std::vector orignal_welfords( + std::vector original_welfords( ir_utils::filterByType(exprs).begin(), ir_utils::filterByType(exprs).end()); - if (wouldTranslateToPersistent(orignal_welfords)) { - for (auto welford : orignal_welfords) { + if (wouldTranslateToPersistent(original_welfords)) { + for (auto welford : original_welfords) { translateSingleWelford(welford); } translated_any_welford_ = true; @@ -2707,26 +2707,26 @@ bool TranslateApplicableWelford::isValidPersistentFusion( // Note that when segmented it is assumed that insertion of lower // precision cast has already been done bool TranslateApplicableWelford::wouldTranslateToPersistent( - const std::vector& orignal_welfords, + const std::vector& original_welfords, SegmentedGroup* group) { - if (orignal_welfords.empty()) { + if (original_welfords.empty()) { return false; } // Make sure all welford inputs are not already statistics, e.g. 
// FusionSqueezeOnlyWelford_CUDA - for (auto welford : orignal_welfords) { + for (auto welford : original_welfords) { if (!welford->inN()->isOneInt()) { return false; } } // Make sure all welford ops come from the same complete fusion - auto fusion = orignal_welfords[0]->fusion(); + auto fusion = original_welfords[0]->fusion(); NVF_ERROR( std::all_of( - orignal_welfords.begin(), - orignal_welfords.end(), + original_welfords.begin(), + original_welfords.end(), [fusion](WelfordOp* welford) { return welford->fusion() == fusion; }), "Welfords in given vector not in the same fusion"); @@ -2736,8 +2736,8 @@ bool TranslateApplicableWelford::wouldTranslateToPersistent( std::vector copied_welfords; std::transform( - orignal_welfords.begin(), - orignal_welfords.end(), + original_welfords.begin(), + original_welfords.end(), std::back_inserter(copied_welfords), [&original_to_test_map](auto welford) { return original_to_test_map.clone(welford); @@ -3559,9 +3559,7 @@ bool SegmentCandidateFinder::codeGenSupportedMerge( } return true; } - NVF_ERROR(runtime_info_.has_value(), "needs runtime info"); - auto h = - tryMerge(segmented_fusion_.get(), runtime_info_.value(), group1, group2); + auto h = tryMerge(segmented_fusion_.get(), runtimeInfo(), group1, group2); return h.has_value(); } @@ -3569,10 +3567,12 @@ bool SegmentCandidateFinder::codeGenSupportedMerge( // called twice ScheduleHeuristic SegmentCandidateFinder::deriveHeuristic( SegmentedGroup* group) { - if (!runtime_info_.has_value()) { + if (options_.only_segment_resharding_exprs) { + // We don't need to generate a heuristic for multidevice segments at this + // moment return ScheduleHeuristic::None; } - auto h = tryMerge(segmented_fusion_.get(), runtime_info_.value(), group); + auto h = tryMerge(segmented_fusion_.get(), runtimeInfo(), group); NVF_ERROR( h.has_value(), "Can not find a scheduler to schedule fusion segment"); return h.value(); @@ -3580,20 +3580,21 @@ ScheduleHeuristic SegmentCandidateFinder::deriveHeuristic( SegmentCandidateFinder::SegmentCandidateFinder( std::unique_ptr fusion, - SegmentCandidateFinderOptions options) - : options_(options) { - segmented_fusion_ = std::make_unique(std::move(fusion)); - findSegments(); -} - -SegmentCandidateFinder::SegmentCandidateFinder( - std::unique_ptr fusion, - const KernelArgumentHolder& inputs, + const KernelArgumentHolder* inputs, SegmentCandidateFinderOptions options) : options_(options), runtime_info_( - std::make_optional(fusion.get(), inputs)), - runtime_inputs_(std::make_optional(inputs)) { + inputs == nullptr ? 
std::nullopt + : std::make_optional( + fusion.get(), + *inputs)), + runtime_inputs_(inputs) { + NVF_ERROR( + !options_.only_segment_resharding_exprs || + (!options_.run_translate_welford && + !options_.run_combine_reductions && options_.run_herrmann_merge && + options_.run_final_merge), + "Invalid Segmenter options"); segmented_fusion_ = std::make_unique(std::move(fusion)); findSegments(); } @@ -3712,9 +3713,9 @@ void SegmentCandidateFinder::findSegments() { ir_utils::hasOpsOfType(segmented_fusion_->completeFusion()); if (options_.run_translate_welford && has_welford_ops) { - NVF_ERROR(runtime_inputs_.has_value()); + NVF_ERROR(runtime_inputs_); if (TranslateApplicableWelford::run( - segmented_fusion_.get(), runtime_inputs_.value())) { + segmented_fusion_.get(), *runtime_inputs_)) { // If modified, rebuild segments as existing expressions may be // pulled into welford groups buildInitialSegments(); @@ -3729,9 +3730,8 @@ void SegmentCandidateFinder::findSegments() { group->setHeuristic(deriveHeuristic(group)); } } - // Remove all scalar edges since they do not represent actual - // dependency among segmented groups. + // dependency among segmented groups. removeScalarEdges(); // Run pre-merge heuristics diff --git a/csrc/fusion_segmenter.h b/csrc/fusion_segmenter.h index 63e0854db15..4f26054ba86 100644 --- a/csrc/fusion_segmenter.h +++ b/csrc/fusion_segmenter.h @@ -559,7 +559,7 @@ class SegmentCandidateFinder { // Perform segmentation on a copy of the given fusion static std::unique_ptr segment( const Fusion* fusion, - const KernelArgumentHolder& inputs, + const KernelArgumentHolder* inputs, SegmentCandidateFinderOptions options = SegmentCandidateFinderOptions()) { auto fusion_copy = std::make_unique(*fusion); return segment(std::move(fusion_copy), inputs, options); @@ -568,7 +568,7 @@ class SegmentCandidateFinder { // Perform segmentation on and take ownership of the given fusion static std::unique_ptr segment( std::unique_ptr fusion, - const KernelArgumentHolder& inputs, + const KernelArgumentHolder* inputs, SegmentCandidateFinderOptions options = SegmentCandidateFinderOptions()) { if (isDebugDumpEnabled(DebugDumpOption::FusionSegments)) { debug() << "Segment the fusion (Original Fusion Un-modified): " @@ -579,22 +579,9 @@ class SegmentCandidateFinder { return std::move(scf.segmented_fusion_); } - // Perform segmentation on and take ownership of the given fusion - static std::unique_ptr segment( - std::unique_ptr fusion, - SegmentCandidateFinderOptions options) { - if (isDebugDumpEnabled(DebugDumpOption::FusionSegments)) { - debug() << "Segment the fusion (Original Fusion Un-modified): " - << std::endl; - fusion->printMath(); - } - SegmentCandidateFinder scf(std::move(fusion), options); - return std::move(scf.segmented_fusion_); - } - static std::unique_ptr segment( std::unique_ptr fusion, - const KernelArgumentHolder& inputs, + const KernelArgumentHolder* inputs, SchedulerRuntimeInfo& runtime_info); static bool hasSegmentHints(Fusion* fusion); @@ -607,11 +594,7 @@ class SegmentCandidateFinder { // Perform segmentation on and take ownership of the given fusion SegmentCandidateFinder( std::unique_ptr fusion, - const KernelArgumentHolder& inputs, - SegmentCandidateFinderOptions options); - - SegmentCandidateFinder( - std::unique_ptr fusion, + const KernelArgumentHolder* inputs, SegmentCandidateFinderOptions options); void resetTraversal(); @@ -660,8 +643,7 @@ class SegmentCandidateFinder { } ExpressionEvaluator& expressionEvaluator() { - NVF_ERROR(runtime_info_.has_value(), "needs runtime 
info"); - return runtime_info_->expressionEvaluator(); + return runtimeInfo().expressionEvaluator(); } //! Additional merging iteration, clean up the rest of @@ -771,6 +753,8 @@ class SegmentCandidateFinder { // unary ops on inputs to the complete fusion VectorOfUniqueEntries excluded_inp_unary_exprs_; + // This is allowed to be null in the multidevice case where the segmenter is + // used for breaking the fusion into compute and communication segments std::optional runtime_info_; //! Note: @@ -789,7 +773,7 @@ class SegmentCandidateFinder { //! TODO: //! implement the expression evaluator transfer and //! remove runtime_inputs_ in a follow up. - std::optional runtime_inputs_; + const KernelArgumentHolder* runtime_inputs_; }; // TODO: Make as member functions on classes instead of global scope diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 0790586fe19..c83c7241ea4 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -336,14 +336,29 @@ std::string IdModel::toString() const { ValGraph IdModel::initializeIdGraph(bool propagate_through_exprs) { ValGraph id_graph(propagate_through_exprs); + // To deterministically initialize the graph, the order of adding + // domains must be deterministic. Here, we sort all IDs by their + // names. + + std::vector all_ids; + all_ids.reserve(id_definitions_.size()); for (const auto& [id, defs] : id_definitions_) { + all_ids.push_back(id); + } + + std::sort( + all_ids.begin(), all_ids.end(), [](IterDomain* id1, IterDomain* id2) { + return id1->name() < id2->name(); + }); + + for (auto id : all_ids) { auto uses_it = id_uses_.find(id); NVF_ERROR( uses_it != id_uses_.end(), "Failed to initialize id: ", id->toString(), " as it's missing a definition entry."); - id_graph.initializeVal(id, defs, uses_it->second); + id_graph.initializeVal(id, id_definitions_.at(id), uses_it->second); } return id_graph; diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h index 151fa036b99..ebae8db5679 100644 --- a/csrc/ir/internal_nodes.h +++ b/csrc/ir/internal_nodes.h @@ -1445,6 +1445,10 @@ class MmaOp : public Expr { return attribute(ATTR_POS_BATCH_AXES); } + std::vector evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const override; + private: // Predefined idexes of attributes stored for this IR node, to avoid // magic numbers, based on order in which attributes are initialized diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index e9065474985..004f2e6330f 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -1309,8 +1309,6 @@ SqueezeOp::SqueezeOp( id->isBroadcast() || id->isSymbolic(), "Squeeze dimension should be either Symbolic or Broadcast. Found ", id->getIterType()); - NVF_ERROR( - !id->hasExpandedExtent(), "Can not squeeze expanded dimension(s)."); if (id->isBroadcast()) { // Check concrete broadcast extent here. For Symbolic inputs, this check // will be deferred to concretization. See dynamic_transform.cpp @@ -1356,12 +1354,20 @@ std::vector SqueezeOp::evaluate( NVF_ERROR( (int64_t)is_squeeze_dims.size() == in.dim(), "The dimensions of input tensor and does not match with is_squeeze_dims"); + at::Tensor out = in; for (int64_t i : c10::irange((int64_t)is_squeeze_dims.size())) { - if (!is_squeeze_dims[i]) { + if (is_squeeze_dims[i]) { + if (in.stride(i) == 0) { + // If the input dimension is expanded in this dimension, undo the expand + // by slicing. 
This ensures that any broadcast dimensions will be + // unexpanded when we do the final call to view() + out = out.slice(i, 0, 1); + } + } else { out_shape.push_back(in.sizes()[i]); } } - return {in.view(out_shape)}; + return {out.view(out_shape)}; } void SqueezeOp::checkConcretization(Val* old_val, Val* new_val) const { @@ -2063,6 +2069,63 @@ void MmaOp::setMacro(MmaMacro macro) { attribute(ATTR_POS_MACRO) = macro; } +std::vector MmaOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + const auto tv_a = inA()->as(); + const auto tv_b = inB()->as(); + NVF_CHECK( + + tv_a->nDims() == tv_b->nDims(), + "Either both or none of A and B should be batch"); + // Verify that the broadcasted size is 3. + NVF_CHECK( + tv_a->nDims() == 3, + "MmaOp::evaluate is not implemented for size: ", + tv_a->nDims()); + + // Assumptions: + // Currently, the evaluate method assumes that the MmaOp is preceded by a + // broadcast. The inputs to MmaOp are broadcasted as the last dim for the + // first operand and the first dim for the second operand. + // The inputs here will be [M, K, 1] x [1, K, N]. + NVF_CHECK( + input(0)->definition() != nullptr && + input(0)->definition()->isA(), + "Currently, MmaOp::evaluate assumes the preceding op to be a broadcast."); + NVF_CHECK( + input(1)->definition() != nullptr && + input(1)->definition()->isA(), + "Currently, MmaOp::evaluate assumes the preceding op to be a broadcast."); + + NVF_CHECK( + tv_a->getRootDomain().back()->isBroadcast(), + "Expected last dimension to be broadcasted for first operand."); + NVF_CHECK( + tv_b->getRootDomain().front()->isBroadcast(), + "Expected first dimension to be broadcasted for second operand."); + + // Squeeze the inputs to remove the broadcasted dimensions. + const auto in_a = inputs.at(0).as().squeeze(-1); + const auto in_b = inputs.at(1).as().squeeze(0); + + // After removing the broadcast dimensions, the format should be + // [M, K] x [K, N] compatible with aten::matmul format. + auto output = in_a.matmul(in_b); + + // ATen preserves the input dtype whereas MmaOP generates float outputs. + // Cast to the dtype of the MmaOp output for consistency. + // NOTE: MmaOp returns the float output, whereas in the evaluate method, + // we are casting from float -> input_dtype -> float. This will lead + // to loss of precision. + // MmaOp::evaluate should be modified to effectively handle cast(MmaOp(H, + // H), H) This will avoid the above cast chain and precision issue. 
+ if (tv_a->getDataType() != out()->getDataType().value()) { + output = output.to(data_type_to_aten(out()->getDataType().value())); + } + return {output}; +} + NVFUSER_DEFINE_CLONE_AND_CREATE(MmaOp) ExpandOp::ExpandOp( @@ -2103,7 +2166,7 @@ std::vector ExpandOp::evaluate( for (auto i : c10::irange(1, inputs.size())) { expanded_size.push_back((int64_t)inputs.at(i)); } - return {at::expand_copy(in, expanded_size)}; + return {in.expand(expanded_size)}; } NVFUSER_DEFINE_CLONE_AND_CREATE(ExpandOp) diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index caacf9a2d3a..1c6f26e7c40 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -743,15 +743,6 @@ std::vector allIDsOf(const TensorView* tv) { return std::vector(all_ids.begin(), all_ids.end()); } -bool isSelectInput(TensorView* tv) { - for (auto expr : tv->uses()) { - if (expr->isA()) { - return true; - } - } - return false; -} - bool isIndexSelectLookupTv(const TensorView* tv) { for (auto expr : tv->uses()) { if (expr->isA()) { diff --git a/csrc/ir/utils.h b/csrc/ir/utils.h index b760c561049..2fc594a53c9 100644 --- a/csrc/ir/utils.h +++ b/csrc/ir/utils.h @@ -441,9 +441,6 @@ IterDomain* getConsumerOfIndexedProducerID(const Expr* expr); // unique. std::vector allIDsOf(const TensorView* tv); -// Check if the given tv is an input of SelectOp -bool isSelectInput(TensorView* tv); - // Check if the given tv is first argment of index_select(lookup, dim, indices) bool isIndexSelectLookupTv(const TensorView* tv); @@ -574,6 +571,17 @@ bool hasOpsOfType(Fusion* fusion) { return false; } +//! Returns true if tv is used by any ops of the given type. +template +bool isTvUsedByOpsOfType(TensorView* tv) { + for (auto expr : tv->uses()) { + if (expr->isOneOf()) { + return true; + } + } + return false; +} + //! Returns expressions that are of type ReductionOp, GroupedReductionOp, or //! WelfordOp. std::vector getAllTypesOfReductionOps(Fusion* fusion); diff --git a/csrc/kernel.cpp b/csrc/kernel.cpp index 99214dc6f6b..e8dc1eba40b 100644 --- a/csrc/kernel.cpp +++ b/csrc/kernel.cpp @@ -128,7 +128,11 @@ class KernelIrScanner : private IrVisitor { } void handle(GridReduction* grid_reduction) final { - summary_.has_grid_reductions = true; + // summary.has_grid_reductions is used to determine whether we need a + // reduction workspace. Serial grid reductions do not require this + // workspace. + summary_.has_grid_reductions = + grid_reduction->serialReductionTensor() == nullptr; if (grid_reduction->isAllreduce()) { summary_.has_cooperative_grid_reduction = true; } diff --git a/csrc/kernel_cache.cpp b/csrc/kernel_cache.cpp index 256db382add..16a6e787d56 100644 --- a/csrc/kernel_cache.cpp +++ b/csrc/kernel_cache.cpp @@ -1046,7 +1046,7 @@ FusionKernelRuntime::FusionKernelRuntime( // Default compilation path applies segmentation before scheduling and // compiling the fusion. segmented_fusion_ = - SegmentCandidateFinder::segment(std::move(fusion), args, runtime_info); + SegmentCandidateFinder::segment(std::move(fusion), &args, runtime_info); } else { // Serialization path that generates segmented fusion from flatbuffers. 
// Convert Welford to two-pass if option is enabled and the original diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 6bf5af5c973..0c1456c50c9 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -5,12 +5,13 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#ifdef USE_DISTRIBUTED +#ifdef NVFUSER_DISTRIBUTED #ifdef USE_C10D_NCCL #include #endif #include +#include namespace nvfuser { namespace { @@ -236,7 +237,13 @@ c10::intrusive_ptr Reduce::post( #ifdef USE_C10D_NCCL auto nccl_backend = dynamic_cast(team_backend.get()); if (nccl_backend) { +#if NVF_TORCH_VERSION_GREATER(2, 2, 0) + // API change https://github.com/pytorch/pytorch/pull/119421 + return nccl_backend->_reduce_oop( + buf.at(0), params_.src_bufs.at(0), options); +#else return nccl_backend->_reduce_oop(buf, params_.src_bufs, options); +#endif } #endif if (comm.deviceId() == params_.root) { diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index 9d3a38d8c8f..7c000176b45 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -6,7 +6,7 @@ */ // clang-format on #pragma once -#ifdef USE_DISTRIBUTED +#ifdef NVFUSER_DISTRIBUTED #include #include diff --git a/csrc/multidevice/communicator.cpp b/csrc/multidevice/communicator.cpp index d387dd2c6ab..8ee90bebbde 100644 --- a/csrc/multidevice/communicator.cpp +++ b/csrc/multidevice/communicator.cpp @@ -5,7 +5,7 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#ifdef USE_DISTRIBUTED +#ifdef NVFUSER_DISTRIBUTED #include #include diff --git a/csrc/multidevice/communicator.h b/csrc/multidevice/communicator.h index 454c8a585df..4b9f0d26f6f 100644 --- a/csrc/multidevice/communicator.h +++ b/csrc/multidevice/communicator.h @@ -6,7 +6,7 @@ */ // clang-format on #pragma once -#ifdef USE_DISTRIBUTED +#ifdef NVFUSER_DISTRIBUTED #include #include diff --git a/csrc/multidevice/executor.cpp b/csrc/multidevice/executor.cpp index bd68017445c..74d43cc9fd8 100644 --- a/csrc/multidevice/executor.cpp +++ b/csrc/multidevice/executor.cpp @@ -5,7 +5,7 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#ifdef USE_DISTRIBUTED +#ifdef NVFUSER_DISTRIBUTED #include #include #include @@ -105,7 +105,8 @@ MultiDeviceExecutor::MultiDeviceExecutor( .run_final_merge = true, .only_segment_resharding_exprs = true}; - staged_fusion_ = SegmentCandidateFinder::segment(std::move(fusion), options); + staged_fusion_ = + SegmentCandidateFinder::segment(std::move(fusion), nullptr, options); for (auto group : staged_fusion_->groups()) { NVF_ERROR(!group->exprs().empty() == 1, "invalid segmentation"); diff --git a/csrc/multidevice/executor.h b/csrc/multidevice/executor.h index 722b7627096..cafd4988ec4 100644 --- a/csrc/multidevice/executor.h +++ b/csrc/multidevice/executor.h @@ -5,7 +5,7 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#ifdef USE_DISTRIBUTED +#ifdef NVFUSER_DISTRIBUTED #pragma once #include diff --git a/csrc/multidevice/lower_communication.cpp b/csrc/multidevice/lower_communication.cpp index efee370d3fa..c910a2d10e0 100644 --- a/csrc/multidevice/lower_communication.cpp +++ b/csrc/multidevice/lower_communication.cpp @@ -5,7 +5,7 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#ifdef USE_DISTRIBUTED +#ifdef NVFUSER_DISTRIBUTED #include #include #include @@ -529,8 +529,8 @@ bool isLowerableToCommunication(Expr* expr) { // get the reduced axis std::vector reduction_axis; std::copy_if( - 
out->getMaybeRFactorDomain().begin(), - out->getMaybeRFactorDomain().end(), + out->getRootDomain().begin(), + out->getRootDomain().end(), std::back_inserter(reduction_axis), [](IterDomain* id) { return id->isReduction(); }); // check whether the reduction involves only one axis @@ -549,4 +549,18 @@ bool isLowerableToCommunication(Expr* expr) { } // namespace nvfuser +#else // NVFUSER_DISTRIBUTED + +#include + +namespace nvfuser { + +// This is just here so that things can compile even when/if NVFUSER_DISTRIBUTED +// is not defined. The code paths aren't intended to be hit ever in such cases, +// so the implementation is unimportant. +bool isLowerableToCommunication(Expr*) { + return false; +} + +} // namespace nvfuser #endif diff --git a/csrc/multidevice/lower_communication.h b/csrc/multidevice/lower_communication.h index 1355e73ed92..125b7c3f76a 100644 --- a/csrc/multidevice/lower_communication.h +++ b/csrc/multidevice/lower_communication.h @@ -5,7 +5,7 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#ifdef USE_DISTRIBUTED +#ifdef NVFUSER_DISTRIBUTED #pragma once #include @@ -27,4 +27,12 @@ std::vector> lowerCommunication( at::Tensor output_tensor); } // namespace nvfuser +#else // NVFUSER_DISTRIBUTED + +namespace nvfuser { + +bool isLowerableToCommunication(Expr*); + +} + #endif diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index aec7a86c21d..3c4933be5d4 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -232,46 +232,13 @@ TensorView* squeeze(TensorView* x, const std::vector& dims) { to_squeeze[dim] = true; } - std::vector out_domain; - for (const auto idx : c10::irange(ndims)) { - auto id = x_dom[idx]; - if (to_squeeze[idx]) { - if (!id->isSymbolic()) { - // If a squeeze is attempted on a non-broadcast dimension - // just don't do it! This conforms with Pytorch. - if (!id->isBroadcast()) { - to_squeeze[idx] = false; - out_domain.push_back(id->cloneWithoutRFactor()); - continue; - } - NVF_CHECK( - !id->hasExpandedExtent(), "Can not squeeze expanded dimension(s)."); - NVF_CHECK( - id->extent()->isConstScalar() && id->extent()->evaluate() == 1, - "Can not squeeze dimension(s) with size != 1."); - } - } else { - out_domain.push_back(id->cloneWithoutRFactor()); - } - } - - auto out = IrBuilder::create( - IrBuilder::create( - out_domain, TensorDomain::getContiguityFilledWith(out_domain, true)), - *x->getDataType()); - - std::vector all_false(to_squeeze.size(), false); - // If a squeeze does not perform a squeeze, create a no-op - if (to_squeeze == all_false) { - IrBuilder::create(LoadStoreOpType::Set, out, x); - } else { - IrBuilder::create(x->container(), out, x, to_squeeze); - } - - return out; + return squeeze(x, to_squeeze); } -TensorView* squeeze(TensorView* x, const std::vector& to_squeeze) { +TensorView* squeeze( + TensorView* x, + const std::vector& to_squeeze, + bool squeeze_expanded) { NVF_ERROR(x != nullptr, "Input is invalid."); auto x_dom = x->domain()->noReductions(); const auto ndims = static_cast(x_dom.size()); @@ -292,7 +259,10 @@ TensorView* squeeze(TensorView* x, const std::vector& to_squeeze) { id->isBroadcast(), "Can not squeeze non-broadcasting dimension(s)."); NVF_CHECK( - !id->hasExpandedExtent(), "Can not squeeze expanded dimension(s)."); + squeeze_expanded || !id->hasExpandedExtent(), + "Refusing to squeeze expanded IterDomain ", + id->toString(), + ". 
To force removal of this axis, use squeeze_expanded=true."); NVF_CHECK( id->extent()->isConstScalar() && id->extent()->evaluate() == 1, "Can not squeeze dimension(s) with size != 1."); @@ -307,7 +277,13 @@ TensorView* squeeze(TensorView* x, const std::vector& to_squeeze) { out_domain, TensorDomain::getContiguityFilledWith(out_domain, true)), *x->getDataType()); - IrBuilder::create(x->container(), out, x, to_squeeze); + if (std::none_of( + to_squeeze.begin(), to_squeeze.end(), [](bool b) { return b; })) { + // If we did not squeeze any axes, this is just set() + IrBuilder::create(LoadStoreOpType::Set, out, x); + } else { + IrBuilder::create(x->container(), out, x, to_squeeze); + } return out; } diff --git a/csrc/ops/alias.h b/csrc/ops/alias.h index 01e7ca16e74..95aa1084457 100644 --- a/csrc/ops/alias.h +++ b/csrc/ops/alias.h @@ -43,12 +43,26 @@ TensorView* reshape(TensorView* x, const std::vector& new_sizes); TensorView* flatten(TensorView* x, int64_t start_dim = 0, int64_t end_dim = -1); -// This implementation is specific to Pytorch where if you attempt to squeeze -// a non-broadcast dimension, the squeeze does not do anything to that -// dimension and does not trigger an error. +//! Squeeze the selected dimensions. +//! +//! NOTE: This function throws an error when encountering an unsqueezable +//! dimension. This behavior differs from PyTorch. TensorView* squeeze(TensorView* x, const std::vector& dims); -TensorView* squeeze(TensorView* x, const std::vector& to_squeeze); +//! Squeeze the dimensions corresponding to "true" in to_squeeze, i.e. remove +//! those broadcasted dimensions. +//! +//! NOTE: This function throws an error when encountering an unsqueezable +//! dimension. This behavior differs from PyTorch. +//! +//! If squeeze_expanded is true, then expanded Broadcasts will be removed just +//! as if they were not expanded. If squeeze_expanded is false, then it is an +//! error for an expanded broadcast to have a corresponding "true" value in +//! to_squeeze. +TensorView* squeeze( + TensorView* x, + const std::vector& to_squeeze, + bool squeeze_expanded = false); TensorView* unsqueeze(TensorView* x, int dim); diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index 64988264764..e7607dbac84 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -1291,36 +1291,85 @@ TensorView* reductionOp( return maybe_full; } + // [Trivial reductions] + // When we reduce a simple broadcast axis like bS0{1} the effect is just to + // squeeze out that broadcast axis. When the axis is expanded, such as bS1{1 + // ex i0}, then the effect depends on the op type. We have the following + // mappings from op_type to expanded reduction equivalent: + // Add -> multiplication by i0 + // Mul -> raise to power i0 + // Min/Max -> squeeze + // {Logical,Bitwise}{And,Or} -> squeeze + // BitwiseXor -> 0 if i0 is even, else squeeze + // Eq -> squeeze + // Gcd -> squeeze + // Other op-types are non-commutative, so we ignore them here as they should + // not be used in reductions. We can see that the only two common ops that + // require special consideration are Add and Mul. Currently Xor is not + // supported for expanded reduction. We treat all others as trivial (i.e. + // squeeze). 
std::vector reduction_axes; - std::vector is_trivial_reduction(ndims, false); + std::vector is_squeeze(ndims, false); + bool expand_reductions_are_trivial = reduction_op_type != BinaryOpType::Add && + reduction_op_type != BinaryOpType::Mul && + reduction_op_type != BinaryOpType::BitwiseXor; int offset = 0; for (unsigned int axis : uint_axes) { auto id = tv_root[axis]; - is_trivial_reduction[axis] = id->isBroadcast() && - !id->hasExpandedExtent() && id->extent()->isConstInt() && - id->extent()->evaluate().as() == 1; - if (!is_trivial_reduction[axis]) { - reduction_axes.push_back((int)axis + offset); - } else if (!keep_dim) { + if (id->isBroadcast()) { + is_squeeze[axis] = true; offset--; + } else { + reduction_axes.push_back((int)axis + offset); } } TensorView* squeezed = tv; if (offset < 0) { - squeezed = squeeze(tv, is_trivial_reduction); + // There are some broadcast dims being reduced. We squeeze them all first. + squeezed = squeeze(tv, is_squeeze, /*squeeze_expanded=*/true); } TensorView* out = squeezed; if (!reduction_axes.empty()) { - return reductionOpRaw( + out = reductionOpRaw( reduction_op_type, reduction_axes, init, squeezed, keep_dim, dtype); } + if (!expand_reductions_are_trivial) { + Val* factor = nullptr; + for (auto axis : uint_axes) { + IterDomain* id = tv_root[axis]; + if (id->isBroadcast() && id->hasExpandedExtent()) { + factor = + SimplifyingIrBuilder::mulExpr(factor, id->getMaybeExpandedExtent()); + } + } + if (factor != nullptr) { + factor = SimplifyingIrBuilder::maybeCastExpr(out->dtype(), factor); + if (reduction_op_type == BinaryOpType::Add) { + out = mul(out, factor); + } else if (reduction_op_type == BinaryOpType::Mul) { + out = pow(out, factor); + } else { + NVF_ERROR( + false, + "Add and Mul are the only non-trivial expand reductions allowed"); + } + } + } + + if (keep_dim && offset < 0) { + // There were squeezed dimension removed from squeeze that will not be + // restored by reductionOpRaw above, so we restore them here + out = broadcast(out, is_squeeze); + } + if (out == tv) { // makes sure that a new tensor is created return set(tv); } + return out; } @@ -1767,7 +1816,7 @@ WelfordResult Welford( TensorView* squeezed = tv; if (offset < 0) { - squeezed = squeeze(tv, is_trivial_reduction); + squeezed = squeeze(tv, is_trivial_reduction, /*squeeze_expanded=*/true); } if (!reduction_axes.empty()) { diff --git a/csrc/ops/normalization.cpp b/csrc/ops/normalization.cpp index e2ac72a1a21..b77222ed942 100644 --- a/csrc/ops/normalization.cpp +++ b/csrc/ops/normalization.cpp @@ -559,19 +559,19 @@ ForwardNormResult batch_norm( auto cast_output = castOp(*rm_dtype, aliased_output); fusion->aliasOutputToInput( - cast_output, input_to_cast, AliasType::InplaceUpdate); + cast_output, input_to_cast, AllocationType::InplaceUpdate); }; if (running_mean->isFusionInput()) { fusion->aliasOutputToInput( - new_mean_hat, running_mean, AliasType::InplaceUpdate); + new_mean_hat, running_mean, AllocationType::InplaceUpdate); } else { cast_to_input_dtype(running_mean, new_mean_hat); } if (running_var->isFusionInput()) { fusion->aliasOutputToInput( - new_var_hat, running_var, AliasType::InplaceUpdate); + new_var_hat, running_var, AllocationType::InplaceUpdate); } else { cast_to_input_dtype(running_var, new_var_hat); } @@ -809,7 +809,7 @@ ForwardNormResult instance_norm( castOp(running_mean->getDataType().value(), new_mean_channels_only); } fusion->aliasOutputToInput( - new_mean_channels_only, running_mean, AliasType::InplaceUpdate); + new_mean_channels_only, running_mean, 
AllocationType::InplaceUpdate); auto num_feature_decrement = sub(N, x->container()->oneVal(N->dtype())); auto unbiased_var = @@ -828,7 +828,7 @@ ForwardNormResult instance_norm( castOp(running_var->getDataType().value(), new_var_channels_only); } fusion->aliasOutputToInput( - new_var_channels_only, running_var, AliasType::InplaceUpdate); + new_var_channels_only, running_var, AllocationType::InplaceUpdate); } mean = welford_out.avg; diff --git a/csrc/python_frontend/fusion_cache.cpp b/csrc/python_frontend/fusion_cache.cpp index 7b408d5014b..4479153f378 100644 --- a/csrc/python_frontend/fusion_cache.cpp +++ b/csrc/python_frontend/fusion_cache.cpp @@ -42,14 +42,19 @@ std::string getSerdeTmpFile() { return ss.str(); } -std::string getSerdeFile() { - auto device_prop = at::cuda::getCurrentDeviceProperties(); +std::string getSerdeFile(std::optional device_id) { + auto device_prop = (device_id.has_value()) + ? at::cuda::getDeviceProperties(device_id.value()) + : at::cuda::getCurrentDeviceProperties(); int cuda_major = 0; int cuda_minor = 0; NVFUSER_NVRTC_SAFE_CALL(nvrtcVersion(&cuda_major, &cuda_minor)); std::stringstream ss; ss << "nvf_serde"; + if (device_id.has_value()) { + ss << "_rank" << device_id.value(); + } ss << "_device" << device_prop->major << "_" << device_prop->minor; ss << "_cuda" << cuda_major << "_" << cuda_minor; return ss.str(); @@ -90,7 +95,9 @@ BinaryBuffer openFusionCache(std::string filename) { } // This check function only throws errors if strict flag is enabled. -const serde::FusionCache* verifyFusionCache(const BinaryBuffer& buffer) { +const serde::FusionCache* verifyFusionCache( + const BinaryBuffer& buffer, + std::optional device_id) { FUSER_PERF_SCOPE("Flatbuffers::verifyFusionCache"); auto fusion_cache_buffer = serde::GetFusionCache(buffer.data()); @@ -106,7 +113,9 @@ const serde::FusionCache* verifyFusionCache(const BinaryBuffer& buffer) { "Failed to verify the schema version of the FusionCache buffer"); // Check device major and minor versions - auto device_prop = at::cuda::getCurrentDeviceProperties(); + auto device_prop = (device_id.has_value()) + ? at::cuda::getDeviceProperties(device_id.value()) + : at::cuda::getCurrentDeviceProperties(); NVF_CHECK( device_prop->major == fusion_cache_buffer->device_major() && device_prop->minor == fusion_cache_buffer->device_minor(), @@ -151,7 +160,8 @@ void serialize() { // Files replaced through this process should remain extant if they are being // read because of UNIX filesystem properties, but this behavior is // unverified. 
- auto file_path = getSerdeFilePath(getSerdeFile()); + auto file_path = + getSerdeFilePath(getSerdeFile(FusionCache::get()->deviceId())); std::error_code rename_ec; fs::rename(tmp_file_path, file_path, rename_ec); @@ -230,14 +240,16 @@ flatbuffers::Offset TrieNode::serialize( FusionCache* FusionCache::get( size_t max_fusions, + std::optional selected_device, bool load_from_default_workspace) { FUSER_PERF_SCOPE("FusionCache::get"); std::lock_guard guard(singleton_lock_); if (singleton_ == nullptr) { - singleton_ = new FusionCache(max_fusions); + singleton_ = new FusionCache(max_fusions, selected_device); // Deserialize cache hierarchy from common workspace automatically - auto file_path = getSerdeFilePath(getSerdeFile()).native(); + auto file_path = + getSerdeFilePath(getSerdeFile(singleton_->deviceId())).native(); if (load_from_default_workspace && fs::exists(file_path)) { try { singleton_->deserialize(file_path); @@ -270,7 +282,7 @@ FusionCache* FusionCache::get( // Reset FusionCache if there is an issue with the current workspace. delete singleton_; - singleton_ = new FusionCache(max_fusions); + singleton_ = new FusionCache(max_fusions, selected_device); } } } @@ -285,6 +297,10 @@ size_t FusionCache::numFusions() const { return fusions_.size(); } +std::optional FusionCache::deviceId() const { + return device_id_; +} + void FusionCache::print(std::ostream& os) const { os << "Fusions by id:" << std::endl; std::vector stack; @@ -346,14 +362,18 @@ void FusionCache::stats(std::ostream& os) const { void FusionCache::reset() { std::lock_guard guard(singleton_lock_); if (singleton_ != nullptr) { - auto max_fusions = singleton_->max_fusions_; + size_t max_fusions = singleton_->max_fusions_; + std::optional device_id = singleton_->device_id_; delete singleton_; - singleton_ = new FusionCache(max_fusions); + singleton_ = new FusionCache(max_fusions, device_id); } } -FusionCache::FusionCache(size_t max_fusions) +FusionCache::FusionCache( + size_t max_fusions, + std::optional selected_device) : max_fusions_(max_fusions), + device_id_(selected_device), root_(nullptr), fusions_(), terminal_nodes_(), @@ -598,7 +618,8 @@ void FusionCache::deserialize(std::string filename) { fusions_.empty(), "Deserialization is prohibited if FusionCache is already populated."); const BinaryBuffer& buffer = openFusionCache(filename); - const serde::FusionCache* fusion_cache_buffer = verifyFusionCache(buffer); + const serde::FusionCache* fusion_cache_buffer = + verifyFusionCache(buffer, device_id_); // See table definition for FusionCache in serde/fusion_cache.fbs FUSER_PERF_SCOPE("FusionCache::deserialize"); diff --git a/csrc/python_frontend/fusion_cache.h b/csrc/python_frontend/fusion_cache.h index fd45886c24a..13b555d70de 100644 --- a/csrc/python_frontend/fusion_cache.h +++ b/csrc/python_frontend/fusion_cache.h @@ -118,7 +118,7 @@ struct TrieNode { class FusionCache { //! The constructor is private given the FusionCache is only constructed //! as a singleton. - FusionCache(size_t max_fusions); + FusionCache(size_t max_fusions, std::optional selected_device); public: //! Copy and Assignment of the FusionCache is not supported @@ -130,10 +130,13 @@ class FusionCache { //! Gets a pointer to the singleton and creates a new one if necessary static FusionCache* get( - size_t max_fusions = 8192, + size_t max_fusions = 16384, + std::optional selected_device = std::nullopt, bool load_from_default_workspace = true); //! Number of fusions cached size_t numFusions() const; + //! 
Get device associated with this FusionCache + std::optional deviceId() const; //! print cache contents void print(std::ostream& os) const; //! print cache stats @@ -184,6 +187,9 @@ class FusionCache { //! The max allowed number of fusions in the cache size_t max_fusions_; + //! A separate process is created for each device in a distributed setting. + //! Each FusionCache becomes associated with a device. + std::optional device_id_; //! The root (start) of the prefix tree to start a cache look up of a given //! fusion definition. std::unique_ptr root_; diff --git a/csrc/python_frontend/fusion_record.h b/csrc/python_frontend/fusion_record.h index 683ff8aaa48..a2e332ee60f 100644 --- a/csrc/python_frontend/fusion_record.h +++ b/csrc/python_frontend/fusion_record.h @@ -718,13 +718,15 @@ struct BroadcastInDimOpRecord : RecordFunctor { const auto arg_ndims = arg_domains_nr.size(); NVF_CHECK( output_ndims_ >= arg_ndims, - "The new shape is expected to be greater-then-or-equal to the input", + "The new shape is expected to be greater-then-or-equal to the input: ", output_ndims_, + " vs ", arg_ndims); NVF_CHECK( arg_ndims == broadcast_dims_.size(), - "The broadcast dimensions should match the input dimensions.", + "The broadcast dimensions should match the input dimensions: ", arg_ndims, + " vs ", broadcast_dims_.size()); std::vector is_broadcast_dim(output_ndims_, true); @@ -1564,7 +1566,7 @@ struct ReductionOpRecord : RecordFunctor { void print(std::ostream& os, bool close_function = true) const final { RecordFunctor::print(os, false); - os << ", axes=["; + os << ", dims=["; bool first_arg = true; for (auto axis : axes_) { if (first_arg) { @@ -2090,7 +2092,7 @@ struct NormOpRecord : RecordFunctor { void print(std::ostream& os, bool close_function = true) const final { RecordFunctor::print(os, false); - os << ", axes=["; + os << ", dims=["; bool first_arg = true; for (auto axis : axes_) { if (first_arg) { diff --git a/csrc/python_frontend/fusion_state.cpp b/csrc/python_frontend/fusion_state.cpp index 4625d80c1c8..d827d181ed5 100644 --- a/csrc/python_frontend/fusion_state.cpp +++ b/csrc/python_frontend/fusion_state.cpp @@ -161,9 +161,9 @@ void FusionState::addOutput( void FusionState::aliasOutputToInput(Val* output, Val* input) { NVF_CHECK(fusion_ != nullptr, "Fusion is undefined."); - // We haven't exposed AliasType to Python API. For now, use + // We haven't exposed AllocationType to Python API. For now, use // InplaceUpdate to preserve the old behavior. 
- fusion_->aliasOutputToInput(output, input, AliasType::InplaceUpdate); + fusion_->aliasOutputToInput(output, input, AllocationType::InplaceUpdate); } } // namespace nvfuser::python_frontend diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index 98e9ed7bb1f..2a4f3d17406 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -55,9 +55,6 @@ Vector define_vector_fn( std::vector args; size_t idx = 0; for (const auto& item : values) { - NVF_CHECK( - idx < 8, - "The specified vector size exceeds the max tensor size for nvfuser."); if (py::isinstance(item)) { auto int_value = py::cast(item); NVF_CHECK( @@ -418,7 +415,8 @@ void initNvFuserPythonBindings(PyObject* module) { .def_static( "get", &FusionCache::get, - py::arg("max_fusions") = int(8192), + py::arg("max_fusions") = int(16384), + py::arg("selected_device") = int(-1), py::arg("load_from_default_workspace") = true, py::return_value_policy::reference) .def("num_fusions", &FusionCache::numFusions) @@ -715,10 +713,10 @@ void initNvFuserPythonBindings(PyObject* module) { .def( "define_tensor", [](FusionDefinition& self, - std::vector& shape, - std::vector>& contiguity, - PrimDataType dtype = DataType::Float, - bool is_cpu = false, + const std::vector& shape, + std::vector> contiguity = {}, + const PrimDataType dtype = DataType::Float, + const bool is_cpu = false, std::vector stride_order = {}) -> Tensor { FUSER_PERF_SCOPE("FusionDefinition.define_tensor (default)"); NVF_CHECK( @@ -735,6 +733,16 @@ void initNvFuserPythonBindings(PyObject* module) { " was neither symbolic(-1), zero_element(0), broadcast(1), or static(>1)."); } + if (contiguity.empty()) { + for (const auto dim_size : shape) { + if (dim_size == 1) { + contiguity.emplace_back(std::nullopt); + } else { + contiguity.emplace_back(false); + } + } + } + Tensor out = self.defineTensor(shape.size()); self.defineRecord(new TensorRecord( {self.recordingState(out())}, @@ -747,7 +755,7 @@ void initNvFuserPythonBindings(PyObject* module) { return out; }, py::arg("shape"), - py::arg("contiguity"), + py::arg("contiguity") = py::list(), py::arg("dtype") = DataType::Float, py::arg("is_cpu") = false, py::arg("stride_order") = py::list(), @@ -865,9 +873,6 @@ void initNvFuserPythonBindings(PyObject* module) { fusion_def.def( "define_vector", [](FusionDefinition& self, size_t size) -> Vector { - NVF_CHECK( - size < 8, - "The specified vector size exceeds the max tensor size for nvfuser."); std::vector args; args.reserve(size); for (size_t i = 0; i < size; ++i) { @@ -1906,8 +1911,8 @@ void initNvFuserPythonBindings(PyObject* module) { self.validUse(), "Attempting to add to a completed definition!"); \ FusionDefinition* fd = self.fusion_definition; \ size_t ndims = 0; \ - std::vector axes(arg.dims); \ - std::iota(axes.begin(), axes.end(), 0); \ + std::vector dims(arg.dims); \ + std::iota(dims.begin(), dims.end(), 0); \ Tensor output = fd->defineTensor(ndims); \ fd->defineRecord(new ReductionOpRecord( \ {fd->recordingState(arg())}, \ @@ -1918,7 +1923,7 @@ void initNvFuserPythonBindings(PyObject* module) { const std::vector&, \ bool, \ DataType)>(op_name), \ - axes, \ + dims, \ false, \ dtype)); \ return output; \ @@ -1930,7 +1935,7 @@ void initNvFuserPythonBindings(PyObject* module) { op_str, \ [](FusionDefinition::Operators& self, \ Tensor arg, \ - int axis, \ + int dim, \ bool keepdim, \ PrimDataType dtype) -> Tensor { \ FUSER_PERF_SCOPE("Operators." 
op_str); \ @@ -1948,13 +1953,13 @@ void initNvFuserPythonBindings(PyObject* module) { const std::vector&, \ bool, \ DataType)>(op_name), \ - {axis}, \ + {dim}, \ keepdim, \ dtype)); \ return output; \ }, \ py::arg("arg"), \ - py::arg("axis"), \ + py::arg("dim"), \ py::arg("keepdim") = false, \ py::arg("dtype") = DataType::Null, \ py::return_value_policy::reference); \ @@ -1962,14 +1967,14 @@ void initNvFuserPythonBindings(PyObject* module) { op_str, \ [](FusionDefinition::Operators& self, \ Tensor arg, \ - const std::vector& axes, \ + const std::vector& dims, \ bool keepdim, \ PrimDataType dtype) -> Tensor { \ FUSER_PERF_SCOPE("Operators." op_str); \ NVF_CHECK( \ self.validUse(), "Attempting to add to a completed definition!"); \ FusionDefinition* fd = self.fusion_definition; \ - size_t ndims = keepdim ? arg.dims : (arg.dims - axes.size()); \ + size_t ndims = keepdim ? arg.dims : (arg.dims - dims.size()); \ Tensor output = fd->defineTensor(ndims); \ fd->defineRecord(new ReductionOpRecord( \ {fd->recordingState(arg())}, \ @@ -1980,13 +1985,13 @@ void initNvFuserPythonBindings(PyObject* module) { const std::vector&, \ bool, \ DataType)>(op_name), \ - axes, \ + dims, \ keepdim, \ dtype)); \ return output; \ }, \ py::arg("arg"), \ - py::arg("axes"), \ + py::arg("dims"), \ py::arg("keepdim") = false, \ py::arg("dtype") = DataType::Null, \ py::return_value_policy::reference); @@ -2663,25 +2668,25 @@ void initNvFuserPythonBindings(PyObject* module) { "var", [](FusionDefinition::Operators& self, Tensor arg, - std::vector& axes, + std::vector& dims, int64_t correction, bool keepdim) -> Tensor { FUSER_PERF_SCOPE("Operators.var"); NVF_CHECK( self.validUse(), "Attempting to add to a completed definition!"); FusionDefinition* fd = self.fusion_definition; - size_t ndims = keepdim ? arg.dims : (arg.dims - axes.size()); + size_t ndims = keepdim ? arg.dims : (arg.dims - dims.size()); Tensor output = fd->defineTensor(ndims); fd->defineRecord(new VarianceOpRecord( {fd->recordingState(arg())}, {fd->recordingState(output())}, - std::move(axes), + std::move(dims), correction, keepdim)); return output; }, py::arg("arg"), - py::arg("axes"), + py::arg("dims"), py::arg("correction"), py::arg("keepdim") = false, py::return_value_policy::reference); @@ -2689,26 +2694,26 @@ void initNvFuserPythonBindings(PyObject* module) { "var_mean", [](FusionDefinition::Operators& self, Tensor arg, - std::vector& axes, + std::vector& dims, int64_t correction, bool keepdim) -> decltype(auto) { FUSER_PERF_SCOPE("Operators.var_mean"); NVF_CHECK( self.validUse(), "Attempting to add to a completed definition!"); FusionDefinition* fd = self.fusion_definition; - size_t ndims = keepdim ? arg.dims : (arg.dims - axes.size()); + size_t ndims = keepdim ? 
arg.dims : (arg.dims - dims.size()); Tensor var = fd->defineTensor(ndims); Tensor mean = fd->defineTensor(ndims); fd->defineRecord(new VarianceMeanOpRecord( {fd->recordingState(arg())}, {fd->recordingState(var()), fd->recordingState(mean())}, - std::move(axes), + std::move(dims), correction, keepdim)); return std::make_tuple(var, mean); }, py::arg("arg"), - py::arg("axes"), + py::arg("dims"), py::arg("correction") = 1, py::arg("keepdim") = false, py::return_value_policy::reference); diff --git a/csrc/scheduler/mark_aliases.cpp b/csrc/scheduler/mark_aliases.cpp index 10292c8018b..b061a6bddd9 100644 --- a/csrc/scheduler/mark_aliases.cpp +++ b/csrc/scheduler/mark_aliases.cpp @@ -38,7 +38,8 @@ void markAliases(Fusion* fusion) { continue; } - fusion->aliasOutputToInput(out, aliased_io, AliasType::PointerArithmetic); + fusion->aliasOutputToInput( + out, aliased_io, AllocationType::PointerArithmetic); vlog( "Marked ", ir_utils::varName(out), diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp index 83b310073ff..7c7b354cc60 100644 --- a/csrc/scheduler/matmul.cpp +++ b/csrc/scheduler/matmul.cpp @@ -1100,8 +1100,7 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { // (iS) iBx iBy iTz iTy iS iS iS iTx iS rBz // // This reordering step lets us inline all but the last dim MNi3 (position - // nbatch + 7) which might be vectorized for the epilogue but which we - // can't vectorize for the gridReduce. + // nbatch + 7) which might be vectorized. // // NOTE: we need to do this reorder after the propagation above so that it // doesn't get reset. @@ -1115,6 +1114,8 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { {num_batch_dims + 8, num_batch_dims + 7}, {num_batch_dims + 9, num_batch_dims + 8}, }); + // Vectorize inner-most dimension + splitk_sum->axis(-1)->parallelize(ParallelType::Vectorize); } // auto inline for all tensors except register tensors diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 36f8da4a424..2ec80a11a36 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -36,22 +36,18 @@ inline mma_utils::MmaDataTypes getMmaDataTypes( return mma_utils::MmaDataTypes{a_type, b_type, c_type}; } -std::pair generateSharedMemoryEpilogueHeuristics( +//! Return sizes of smem_a, smem_b, smem_c in bytes +std::tuple computeSharedMemorySizes( const MatMulTileOptions& gemm_tile, - const int smem_double_buffer_stage, - const MmaDataTypes& data_types, - bool smem_a_reuse_guaranteed, - bool smem_b_reuse_guaranteed, - bool ignore_occupancy_drop) { + const MatmulParams::DoubleBufferOptions& double_buffer_options, + const MmaDataTypes& data_types) { const auto properties = at::cuda::getCurrentDeviceProperties(); - const size_t device_smem_limit = properties->sharedMemPerBlockOptin; - const size_t shared_memory_overhead = properties->reservedSharedMemPerBlock; - const size_t shared_memory_available = - device_smem_limit - shared_memory_overhead; auto warp_dims = gemm_tile.cta_tile / gemm_tile.warp_tile; - const auto threads_per_block = - warp_dims.m * warp_dims.n * warp_dims.k * properties->warpSize; + + int ab_factor = double_buffer_options.double_buffer_smem_write + ? 
double_buffer_options.smem_double_buffer_stage + : 1; // see scheduleContiguousVectorLoad const int vector_word = 8; @@ -59,15 +55,61 @@ std::pair generateSharedMemoryEpilogueHeuristics( properties->warpSize * vector_word; const int mk = gemm_tile.cta_tile.m * gemm_tile.cta_tile.k; const int nk = gemm_tile.cta_tile.n * gemm_tile.cta_tile.k; - const size_t smem_a = (size_t)(ceilDiv(mk, round_to_factor) * - round_to_factor * smem_double_buffer_stage) * + const size_t smem_a = + (size_t)(ceilDiv(mk, round_to_factor) * round_to_factor * ab_factor) * dataTypeSize(data_types[0]); - const size_t smem_b = (size_t)(ceilDiv(nk, round_to_factor) * - round_to_factor * smem_double_buffer_stage) * + const size_t smem_b = + (size_t)(ceilDiv(nk, round_to_factor) * round_to_factor * ab_factor) * dataTypeSize(data_types[1]); const size_t smem_c = (size_t)(gemm_tile.cta_tile.m * gemm_tile.cta_tile.n) * dataTypeSize(data_types[2]); + return {smem_a, smem_b, smem_c}; +} + +int64_t computeExpectedSharedMemoryUsage( + const MatmulParams& params, + const MmaDataTypes& data_types, + bool smem_a_reuse_guaranteed, + bool smem_b_reuse_guaranteed) { + const auto [smem_a, smem_b, smem_c] = computeSharedMemorySizes( + params.tile_sizes, params.double_buffer_options, data_types); + + if (params.use_smem_epilogue) { + if (params.promote_prologue_smem_reuse) { + return (int64_t)std::max( + smem_c + (smem_a_reuse_guaranteed ? 0 : smem_a) + + (smem_b_reuse_guaranteed ? 0 : smem_b), + smem_a + smem_b); + } else { + return (int64_t)(smem_a + smem_b + smem_c); + } + } else { + return (int64_t)(smem_a + smem_b); + } +} + +std::pair generateSharedMemoryEpilogueHeuristics( + const MatMulTileOptions& gemm_tile, + const int smem_double_buffer_stage, + const MmaDataTypes& data_types, + bool smem_a_reuse_guaranteed, + bool smem_b_reuse_guaranteed, + bool ignore_occupancy_drop) { + const auto properties = at::cuda::getCurrentDeviceProperties(); + const size_t device_smem_limit = properties->sharedMemPerBlockOptin; + const size_t shared_memory_overhead = properties->reservedSharedMemPerBlock; + const size_t shared_memory_available = + device_smem_limit - shared_memory_overhead; + + // Create a temporary DoubleBufferOptions with full double buffering, for + // estimating shared memory size. + MatmulParams::DoubleBufferOptions double_buffer_options{ + true, true, smem_double_buffer_stage}; + + const auto [smem_a, smem_b, smem_c] = + computeSharedMemorySizes(gemm_tile, double_buffer_options, data_types); + // NOTE: we can simply add these sizes since they should be integer multiples // of 16 bytes, so they will automatically be aligned. This may change with // FP8, in which case the expressions below should be updated to insert @@ -96,6 +138,9 @@ std::pair generateSharedMemoryEpilogueHeuristics( // use additional shared memory for epilogue if occupancy is not changed. // occupancy is estimated using register and shared memory usage. 
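Before the occupancy estimate computed just below, the case analysis in the new computeExpectedSharedMemoryUsage above can be summarized with a small Python sketch; the byte counts fed to it are placeholders supplied by the caller, and the function itself is illustrative only, not part of the patch:

# Illustrative mirror of computeExpectedSharedMemoryUsage's case analysis
# (all sizes in bytes; the values are placeholders in this sketch).
def expected_smem_usage(smem_a: int, smem_b: int, smem_c: int,
                        use_smem_epilogue: bool,
                        promote_prologue_smem_reuse: bool,
                        smem_a_reuse_guaranteed: bool = False,
                        smem_b_reuse_guaranteed: bool = False) -> int:
    if not use_smem_epilogue:
        # Only the operand (prologue) buffers are live.
        return smem_a + smem_b
    if promote_prologue_smem_reuse:
        # The epilogue buffer may alias operand buffers whose reuse is
        # guaranteed, so the peak is the larger of the two phases.
        epilogue = (smem_c
                    + (0 if smem_a_reuse_guaranteed else smem_a)
                    + (0 if smem_b_reuse_guaranteed else smem_b))
        return max(epilogue, smem_a + smem_b)
    # No reuse promotion: all three buffers coexist.
    return smem_a + smem_b + smem_c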
+ auto warp_dims = gemm_tile.cta_tile / gemm_tile.warp_tile; + const auto threads_per_block = + warp_dims.m * warp_dims.n * warp_dims.k * properties->warpSize; const auto threads_per_sm = getThreadsPerSMGivenRegPerThread(255); const auto blocks_per_sm_by_register = threads_per_sm / threads_per_block; const auto blocks_per_sm_without_smem_epilogue = std::min( diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h index dccbc3c2639..ac7feffe36a 100644 --- a/csrc/scheduler/mma_utils.h +++ b/csrc/scheduler/mma_utils.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -389,6 +390,16 @@ class CombineMulSum : public IterVisitor { bool is_valid_ = false; }; +//! Compute the amount of shared memory we expect to need. The actual amount +//! allocated will be determined by aliasing (see alias_memory.cpp). This +//! function is useful for testing that we provide accurate information to our +//! heuristics. +int64_t computeExpectedSharedMemoryUsage( + const MatmulParams& params, + const MmaDataTypes& data_types, + bool smem_a_reuse_guaranteed = false, + bool smem_b_reuse_guaranteed = false); + } // namespace mma_utils } // namespace nvfuser diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index fe3a5b6f457..eae9d4dc7bd 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -127,6 +127,30 @@ void TransposeScheduler::computeHeuristics( namespace { +// If a fusion is segmented, the segmenter will create fusions whose inputs +// contain reduction IterDomains. These reduction IterDomains on input +// TensorViews does not have any meaning, and should just be left untouched. See +// https://github.com/NVIDIA/Fuser/issues/1659#issuecomment-1907053830 +// +// This function checks the inner `n` iterdomain and reorder reduction +// iterdomain to the beginning. +void moveReductionsOut(TensorView* tv, int n) { + if (!tv->isFusionInput()) { + return; + } + + std::unordered_map old2new; + + int target = 0; + for (int i = 0; i < n; i++) { + if (tv->axis(-1 - i)->isReduction()) { + old2new[-1 - i] = target++; + } + } + + tv->reorder(old2new); +} + // TransposeViewPropagator doesn't propagate anything. It simply walks across // the path of potential propagation checking if there's any incompatible // propagation that would not be resolved. 
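As a toy illustration of the reorder map that moveReductionsOut builds before the merge/split calls in the hunks below, here is a Python sketch; it operates on a plain list of is-reduction flags rather than real IterDomains (and normalizes the negative axis indices used in the C++), so it models only the logic, not the actual API:

# Toy model: `axes` is a list of flags (True = reduction IterDomain).
def move_reductions_out(axes, n):
    rank = len(axes)
    old2new = {}
    target = 0
    for i in range(n):
        if axes[-1 - i]:                       # axis(-1 - i)->isReduction()
            old2new[rank - 1 - i] = target     # send it to the front
            target += 1
    # Apply the partial permutation: mapped axes go first, the rest keep
    # their original relative order (as TensorView::reorder does).
    moved = sorted(old2new, key=old2new.get)
    rest = [i for i in range(rank) if i not in old2new]
    return [axes[i] for i in moved + rest]

# The innermost two axes contain one reduction; it is moved to the front.
assert move_reductions_out([False, True, False], n=2) == [True, False, False]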
@@ -1236,6 +1260,7 @@ void scheduleTranspose(Fusion* fusion, TransposeParams params) { int pos = (int)reference2->nDims() - 2; // [..., tile1, tile2] + moveReductionsOut(reference2, 2); reference2->merge(pos); reference2->split(pos, params.vectorize_factor2); reference2->split(pos, params.getThreadsPerBlock()); @@ -1321,6 +1346,7 @@ void scheduleTranspose(Fusion* fusion, TransposeParams params) { reference1->reorder({{-2, -1}}); // [..., tile2, tile1] pos = (int)reference1->nDims() - 2; + moveReductionsOut(reference1, 2); reference1->merge(pos); reference1->split(pos, params.vectorize_factor1); reference1->split(pos, params.getThreadsPerBlock()); diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index 5d9eb0c0d40..4cc2cde02b8 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -1123,9 +1123,10 @@ std::vector cacheInputs(Fusion* fusion, bool unroll) { auto in_tvs = ir_utils::filterByType(fusion->inputs()); for (auto tv : in_tvs) { if (tv->uses().empty() || ir_utils::isTorchGatherLookupTv(tv) || - ir_utils::isSelectInput(tv) || ir_utils::isIndexSelectLookupTv(tv)) { - // Right now, tensors that are input to the select op can't be cached as - // they must be in global memory. + ir_utils::isIndexSelectLookupTv(tv) || + ir_utils::isTvUsedByOpsOfType(tv)) { + // Right now, tensors that are input to the slice, select, and pad ops + // can't be cached as they must be in global memory. continue; } auto cached_tv = tv->cacheAfter(); diff --git a/csrc/scheduler/vectorize_helper.cpp b/csrc/scheduler/vectorize_helper.cpp index b2a1e49ad0f..824fa2b6520 100644 --- a/csrc/scheduler/vectorize_helper.cpp +++ b/csrc/scheduler/vectorize_helper.cpp @@ -436,6 +436,9 @@ ContiguousInnerDimensionsMapper::computeInfoC2P( std::shared_ptr from_info) { auto from_ids = std::dynamic_pointer_cast(from_info) ->mapped_root_ids_; + // When we propagate, we should check the resolved broadcast in the order of + // mapped from_ids. + // // If we have a case where we have a concretized broadcast that's being // tracked in a consumer but not concretized in the producer we should break // off the dimensions connected to the left of that dimension. So if we have: @@ -445,49 +448,47 @@ ContiguousInnerDimensionsMapper::computeInfoC2P( // T3[i0, i1, i2] = T1 + T2 // and we're propogating from T3 with {i0, i1, i2} // When we go from T3 to T0, we don't have any mechanism to understand that i0 - // and i2 are not contiguous in the original domain of T3. It's not ideal with - // transpose, but when this happens we'll clear all dimensions mapped left of - // the concretized broadcast. - // So if we have: - // T0[i1, i2] - // T1[b0, i1, i2] = broadcast(T0) - // T2[i1, b0, i2] = transpose(T1) - // T3[i1, i0, i2] - // T4[i1, i0, i2] = T2 + T3 - // T5[i0, i1, i2] = transpose(T4) - // Then i1 and i2 are contiguous in both T0 and T5, but due to the realization - // of the broadcast on T4 we will have removed i1 from the mapped set. + // and i2 are not contiguous in the original domain of T3. + // + // Another example is that, if the last broadcast dimension resolved in + // consumers root domain is mapped for vectorization, the merge order in + // the vectorization axes matters. + // + // T0[i0, i1] + // T1[i0, i1, b2] = broadcast(T0) + // T2[i0, i1, i3] + // T3[i0, i1, i2] = T1 + T2 + // + // If the mapped ids are {i0, i2, i1}, when propagating from T3 to T1, the + // resolved broadcast iterdomain is `i2`/`b2`, which would give clear_pos=1. + // So we'll skip all from_ids with index < clear_pos. 
see issue: + // https://github.com/NVIDIA/Fuser/issues/1567#issuecomment-1894605385 PairwiseRootDomainMap root_map(to, from); auto c2p_map = root_map.mapConsumerToProducer(); // Id's in consumer to clear from the mapped set due to broadcast // concretization. std::unordered_set consumer_ids_to_clear; + size_t clear_pos = 0; if (to->hasBroadcast()) { - // Find the last broadcast dimension resolved in consumers root domain - int clear_pos = -1; - for (auto i : c10::irange(from->getRootDomain().size())) { - auto c_id = from->getRootDomain()[i]; + // Find the last broadcast dimension resolved in consumers through from_ids + for (int i = (int)from_ids.size() - 1; i >= 0; i--) { + auto c_id = from_ids[i]; auto c_it = c2p_map.find(c_id); if (c_it == c2p_map.end()) { continue; } auto p_id = c_it->second; if ((!c_id->isBroadcast()) && p_id->isBroadcast()) { - clear_pos = (int)i; + clear_pos = (size_t)i + 1; + break; } } - // Clear everything to the left of the inner most resolved broadcast - // dimension, including the broadcasted domain. - if (clear_pos >= 0) { - consumer_ids_to_clear.insert( - from->getRootDomain().begin(), - from->getRootDomain().begin() + clear_pos + 1); - } } std::vector producer_rfactor_ids; - for (auto from_id : from_ids) { + for (auto i : c10::irange(clear_pos, from_ids.size())) { + auto from_id = from_ids[i]; auto c2p_it = c2p_map.find(from_id); if (c2p_it != c2p_map.end() && consumer_ids_to_clear.find(c2p_it->first) == diff --git a/csrc/tensor_view.cpp b/csrc/tensor_view.cpp index 5b39fbe5f97..5f4bc5eca82 100644 --- a/csrc/tensor_view.cpp +++ b/csrc/tensor_view.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -1285,9 +1286,12 @@ TensorView* TensorView::cacheAfter(LoadStoreOpType op_type, CacheOp cache_op) { !hasComputeAt(), "Caching computed-at tensors is not allowed. Apply caching before computeAt."); + bool is_allowed_op = + !ir_utils::isTvUsedByOpsOfType(this) && + !ir_utils::isIndexSelectLookupTv(this); NVF_CHECK( - !ir_utils::isSelectInput(this) && !ir_utils::isIndexSelectLookupTv(this), - "Right now, caching tensors that are input to the select op is not allowed as they must be in global memory.") + is_allowed_op, + "Right now, caching tensors that are input to the select/slice/pad ops are not allowed as they must be in global memory.") // It also did additional transformation when this tensor is an // input and the outputs of its consumers have computeAt. Make sure diff --git a/csrc/utils.h b/csrc/utils.h index 9c8f3f6cddf..b47c8164cc5 100644 --- a/csrc/utils.h +++ b/csrc/utils.h @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -26,6 +27,11 @@ #include #include +#define NVF_TORCH_VERSION_GREATER(major, minor, patch) \ + TORCH_VERSION_MAJOR > major || \ + (TORCH_VERSION_MAJOR == major && TORCH_VERSION_MINOR > minor || \ + (TORCH_VERSION_MINOR == minor && TORCH_VERSION_PATCH > patch)) + //! IR header hierarchy //! 1. ** utils.h ** - PolymorphicBase and NonCopyable //! 2. 
ir/base_nodes.h - Statement, Expr, and Val diff --git a/csrc/val_graph.cpp b/csrc/val_graph.cpp index 3e9324a0c19..f5578a76648 100644 --- a/csrc/val_graph.cpp +++ b/csrc/val_graph.cpp @@ -99,6 +99,58 @@ std::vector ValGraph::inputGroups(const ExprGroup& expr) const { return input_groups; } +ValGroups ValGraph::getTerminatingInputs() const { + // Initialize vals to traverse + ValGroups all_vals{ + disjointValSets().disjointSets().begin(), + disjointValSets().disjointSets().end()}; + + // Initialize exprs to traverse + ExprGroups all_exprs{ + disjointExprSets().disjointSets().begin(), + disjointExprSets().disjointSets().end()}; + + // Grab all vals that are not input, i.e., having a defining expr + // within all_exprs. + // + // Note that an input Val group may be mapped with an output + // group. For example, the AlmostExact graph maps an input of split + // with the outer output if the split factor is one. Such a Val + // group is considered a terminating input as long as the input has + // no defining expression. This is for the use case of + // ValGraphVisitor. + // + // Example: + // + // [i0, i1] + // split by 1 + // [i0/1, 1, i1] + // merge + // [i0/1, 1*i1] + // + // Here, i0 and i0/1 would create a Val group of {i0, i0/1} in the + // AlmostExact graph. This group has a defining expression of the + // split, but since it's a cyclic dependency, we ignore the + // expression and consider the Val group a terminating input. + + ValGroups not_inputs; + for (const ExprGroup& expr_group : all_exprs) { + const std::vector input_groups = inputGroups(expr_group); + const std::vector output_groups = outputGroups(expr_group); + std::unordered_set input_set{ + input_groups.begin(), input_groups.end()}; + + for (const ValGroup& output_group : output_groups) { + if (input_set.count(output_group)) { + continue; + } + not_inputs.pushBack(output_group); + } + } + + return all_vals.computeSubtract(not_inputs); +} + ExprGroups ValGraph::allUsesOf(const ValGroups& of) const { DequeOfExprGroup to_visit; for (const ValGroup& of_val_group : of) { diff --git a/csrc/val_graph.h b/csrc/val_graph.h index f7a20e46bdb..72182e3f2f8 100644 --- a/csrc/val_graph.h +++ b/csrc/val_graph.h @@ -95,6 +95,9 @@ class ValGraph { std::vector outputGroups(const ExprGroup& expr) const; std::vector inputGroups(const ExprGroup& expr) const; + // Return Val groups that have no definition. + ValGroups getTerminatingInputs() const; + // Recursively traverses uses of the IdGroups in 'of' and returns all // ExprGroups that have a use in their definition of provided of IdGroups. ExprGroups allUsesOf(const ValGroups& of) const; diff --git a/csrc/val_graph_visitor.cpp b/csrc/val_graph_visitor.cpp new file mode 100644 index 00000000000..7d6646c747e --- /dev/null +++ b/csrc/val_graph_visitor.cpp @@ -0,0 +1,152 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include + +#include + +namespace nvfuser { + +void ValGraphVisitor::traverse() { + const ValGroups terminating_inputs = graph().getTerminatingInputs(); + std::deque to_visit_vals( + terminating_inputs.begin(), terminating_inputs.end()); + ValGroups visited_vals; + + std::deque to_visit_exprs; + ExprGroups visited_exprs; + + auto is_expr_ready = [&](const ExprGroup& expr_group) -> bool { + const auto inp_groups = graph().inputGroups(expr_group); + return std::all_of( + inp_groups.begin(), inp_groups.end(), [&](ValGroup val_group) { + return visited_vals.has(val_group) || val_group->empty(); + }); + }; + + // If any input of the def expr is mapped with the val + // group itself, i.e., a trivial expr, allow visiting the + // val group first. The trivial expr group will be visited + // after the val group. + // + // Example: + // + // [i0, 1] + // merge + // [i0*1] + // map i0 and i0*1 + // ValGroups: {{i0, i0*1}, {1}} + // + // Then, {i0, i0*1} and {1} would be visited first, then the merge + // expr group would be visited. {i0, i0*1} is also an output group + // of the merge but since it's already in the visited set, it would + // not be visited again. + // + // See also IdModelTest.ValGraphStmtSort3 for a concrete example. + auto is_val_ready = [&](const ValGroup& val_group) -> bool { + if (terminating_inputs.has(val_group)) { + return true; + } + const ExprGroups& unique_defs = graph().getDefinitions(val_group); + return std::all_of( + unique_defs.begin(), unique_defs.end(), [&](ExprGroup expr_group) { + if (expr_group->empty() || visited_exprs.has(expr_group)) { + return true; + } + // Handle ExprGroups that return one or some of its input ValGroups as + // output. This expr_group is not visited yet, which means there're + // input ValGroups that are not yet visited. If those not-visited + // inputs are actually the same as val_group, visit val_group at this + // point to resolve the circular dependency. + for (const ValGroup& input_group : graph().inputGroups(expr_group)) { + if (input_group != val_group && !visited_vals.has(input_group) && + input_group->empty()) { + return false; + } + } + return true; + }); + }; + + // Detect if nothing has been processed which would put us in an infinite + // loop + bool something_was_processed = false; + + do { + something_was_processed = false; + + // Process expressions first as all definitions of vals have to be + // processed before we can process that val. 
+ + while (!to_visit_exprs.empty()) { + ExprGroup current_expr_group = to_visit_exprs.front(); + to_visit_exprs.pop_front(); + NVF_ERROR(!current_expr_group->empty()); + if (visited_exprs.has(current_expr_group)) { + continue; + } + + if (is_expr_ready(current_expr_group)) { + handle(current_expr_group); + + something_was_processed = true; + visited_exprs.pushBack(current_expr_group); + + for (const ValGroup& output_group : + graph().outputGroups(current_expr_group)) { + to_visit_vals.push_back(output_group); + } + } + } + + std::deque still_to_visit_vals; + while (!to_visit_vals.empty()) { + auto current_val_group = to_visit_vals.front(); + to_visit_vals.pop_front(); + NVF_ERROR(!current_val_group->empty()); + if (visited_vals.has(current_val_group)) { + continue; + } + + if (is_val_ready(current_val_group)) { + handle(current_val_group); + + something_was_processed = true; + visited_vals.pushBack(current_val_group); + + for (const ExprGroup& use_group : graph().getUses(current_val_group)) { + to_visit_exprs.push_back(use_group); + } + } else { + still_to_visit_vals.push_back(current_val_group); + } + } + + std::swap(to_visit_vals, still_to_visit_vals); + + } while (something_was_processed); + + if (!to_visit_vals.empty()) { + std::stringstream ss; + ss << "The graph has an infinite loop. The following Vals should be visited but are never ready:"; + for (const ValGroup& vg : to_visit_vals) { + ss << " " << nvfuser::toString(vg); + } + NVF_ERROR(false, ss.str()); + } + + if (!to_visit_exprs.empty()) { + std::stringstream ss; + ss << "The graph has an infinite loop. The following Exprs should be visited but are never ready:"; + for (const ExprGroup& eg : to_visit_exprs) { + ss << " " << nvfuser::toString(eg); + } + NVF_ERROR(false, ss.str()); + } +} + +} // namespace nvfuser diff --git a/csrc/val_graph_visitor.h b/csrc/val_graph_visitor.h new file mode 100644 index 00000000000..54a5ed253b8 --- /dev/null +++ b/csrc/val_graph_visitor.h @@ -0,0 +1,121 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include +#include + +namespace nvfuser { + +// Iterates through a Val Graph in topological order, calling handle on +// all Val and all Expr groups in a forward topological order. +// +// Warning: A ValGraph is not guaranteed to be a DAG. In fact, the +// AlmostExact and Permissive graphs would have cycles with a ValGroup +// and an ExprGroup. For example: +// +// [i0, 1] +// merge +// [i0*1] +// Current ValGroups: {{i0}, {1}, {i0*1}} +// map i0 and i0*1 as they effectively have the same extent +// Final ValGroups: {{i0, i0*1}, {1}} +// +// Here, the merge expr is the user of i0 and the definition of +// i0*1. Since i0 and i0*1 are mapped, the dependency chain looks +// like: +// +// {i0, i0*1} ----> {merge} ----> {i0, i0*1} +// use def +// +// These ExprGroups are called trivial ExprGroups (see also +// ValGraph::isTrivialExprGroup). +// +// Strictly speaking, these cycles mean there's no valid topological +// order anymore. In our use cases for IdModel, however, it's likely +// sufficient to return an ordering such as: +// +// {i0, i0*1} -> {merge} +// +// I.e., we visit {i0, i0*1} first even though {merge} is technically +// a definition. +// +// Another alternative may be simply giving up when such a cycle is +// detected, which may be more preferrable as it would be less +// confusing. 
At this moment, this visitor is only used with graphs +// with no such cycle. Should be revisited when necessary. +// +// Warning: This is not a great iterator if there's a desire to minimize paths +// traveled to simply visit all ValGroups in order. See ExprsBetween to see how +// we might minimize paths. +class ValGraphVisitor { + public: + ValGraphVisitor() = delete; + + ValGraphVisitor& operator=(const ValGraphVisitor& other) = delete; + + ValGraphVisitor& operator=(ValGraphVisitor&& other) = delete; + + virtual ~ValGraphVisitor() = default; + + protected: + ValGraphVisitor(const ValGraph& val_graph) : val_graph_(val_graph) {} + + ValGraphVisitor(const ValGraphVisitor& other) = default; + + ValGraphVisitor(ValGraphVisitor&& other) = default; + + virtual void handle(const ValGroup& val_group) = 0; + virtual void handle(const ExprGroup& expr_group) = 0; + + void traverse(); + + const ValGraph& graph() { + return val_graph_; + }; + + private: + const ValGraph& val_graph_; +}; + +// Statement sorting based on ValGraphVisitor, see warnings to ValGraph Visitor. +class ValGraphStmtSort : public ValGraphVisitor { + public: + ValGraphStmtSort(const ValGraph& val_graph) : ValGraphVisitor(val_graph) { + ValGraphVisitor::traverse(); + } + + // Return non-reference so that code like below can work + // for (auto expr_group: ValGraphStmtSort(graph).exprs()) + ExprGroups exprs() const { + return sorted_exprs_; + } + + ValGroups vals() const { + return sorted_vals_; + } + + ~ValGraphStmtSort() override = default; + + protected: + using ValGraphVisitor::handle; + + void handle(const ValGroup& val_group) override { + sorted_vals_.pushBack(val_group); + } + + void handle(const ExprGroup& expr_group) override { + sorted_exprs_.pushBack(expr_group); + } + + ExprGroups sorted_exprs_; + ValGroups sorted_vals_; +}; + +} // namespace nvfuser diff --git a/nvfuser/__init__.py b/nvfuser/__init__.py index 2e762cf8d48..3719f1c5f2c 100644 --- a/nvfuser/__init__.py +++ b/nvfuser/__init__.py @@ -39,6 +39,13 @@ def enable_automatic_serialization(): atexit.register(_C.serialize) + # A separate process is created for each device in a distributed setting. + # Each FusionCache becomes associated with a single device. + # Automatic serialization saves a separate cache for each device. + # Set the FusionCache id to the ddp local rank. + ddp_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + _C.FusionCache.get(max_fusions := 8192, ddp_local_rank) + # Unregister automatic serialization of Nvfuser cache hierarchy and cuda kernels. 
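As a usage sketch of the per-rank wiring added above (assuming a script launched with torchrun, which sets LOCAL_RANK for each worker process; nothing below is part of the patch itself):

import os
import nvfuser

# Registers serialization at exit and keys this process's FusionCache by its
# ddp local rank, so each rank serializes to its own workspace file instead
# of all ranks writing to a single shared file.
nvfuser.enable_automatic_serialization()
print(f"rank {os.environ.get('LOCAL_RANK', 'unset')}: per-rank cache enabled")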
def disable_automatic_serialization(): diff --git a/python_benchmarks/conftest.py b/python_benchmarks/conftest.py index 7748c22c077..fd545460f55 100644 --- a/python_benchmarks/conftest.py +++ b/python_benchmarks/conftest.py @@ -6,15 +6,24 @@ def pytest_addoption(parser): parser.addoption( "--disable-validation", action="store_true", - default=False, help="Disable output validation in benchmarks.", ) parser.addoption( "--disable-benchmarking", action="store_true", - default=False, help="Disable benchmarking.", ) + parser.addoption( + "--benchmark-eager", + action="store_true", + help="Benchmarks torch eager mode.", + ) + + parser.addoption( + "--benchmark-torchcompile", + action="store_true", + help="Benchmarks torch.compile mode.", + ) @pytest.fixture @@ -33,3 +42,37 @@ def pytest_make_parametrize_id(val): def pytest_benchmark_update_machine_info(config, machine_info): machine_info.update(DEVICE_PROPERTIES) + + +def pytest_collection_modifyitems(session, config, items): + """ + The baseline benchmarks use `compile` parameter: + compile = false: Eager mode benchmark + compile = true: torch.compile benchmark + """ + run_eager = config.getoption("--benchmark-eager") + run_torchcompile = config.getoption("--benchmark-torchcompile") + + if not run_eager: + skip_eager = pytest.mark.skip(reason="need --benchmark-eager option to run") + for item in items: + # If the benchmark has compile=False parameter (eager mode), skip it. + if ( + hasattr(item, "callspec") + and "compile" in item.callspec.params + and not item.callspec.params["compile"] + ): + item.add_marker(skip_eager) + + if not run_torchcompile: + skip_torchcompile = pytest.mark.skip( + reason="need --benchmark-torchcompile option to run" + ) + for item in items: + # If the benchmark has compile=True parameter (torch.compile mode), skip it. + if ( + hasattr(item, "callspec") + and "compile" in item.callspec.params + and item.callspec.params["compile"] + ): + item.add_marker(skip_torchcompile) diff --git a/python_benchmarks/global_params.py b/python_benchmarks/global_params.py index 72a15c805e4..fbc610c53bb 100644 --- a/python_benchmarks/global_params.py +++ b/python_benchmarks/global_params.py @@ -4,6 +4,12 @@ from .core import DEVICE_PROPERTIES import numpy as np import itertools +import os + +# BENCHMARK_MODE = weekly/nightly. +BENCHMARK_MODE = os.getenv("BENCHMARK_MODE") +if not BENCHMARK_MODE: + BENCHMARK_MODE = "nightly" # Datatypes to benchmark FLOAT_DTYPES = [torch.float32] @@ -44,6 +50,18 @@ # Utility function to generate input sizes for benchmarks def generate_input_sizes(dims: Union[int, List] = 2) -> List[Tuple]: + """ + The weekly vs nightly input ranges only differ for 2D inputs currently. + Nightly input range: + Batch size: [16->16384] Hidden size: [768, 4*18432] (step size = 256) + Weekly input range: + Batch size: + [16]: Latency bound state + [512, 1024]: Just filled the machine + [16384]: Steady state (full machine) + Hidden size: [768, 4*18432] (step size = 8) + Note: The hidden size is restricted to 2 * 18432 for the batch size 16384 to avoid OOM. 
+ """ inputs = [] if isinstance(dims, int): dims = [dims] @@ -51,12 +69,16 @@ def generate_input_sizes(dims: Union[int, List] = 2) -> List[Tuple]: for dim in dims: if dim == 2: input_ranges = [] - step_size = 256 + step_size = 256 # max_batch_range: set according to max size that fits in GPU memory batch_range = [2**i for i in range(4, 14)] # {16, 8192} - # max_hidden_size = 4 * d_model_max (max hidden size in feedforward layers) + if BENCHMARK_MODE == "weekly": + step_size = 8 + batch_range = [16, 512, 1024] + + # max_hidden_size = 4 * d_model_max (max hidden size in feedforward layers) # NOTE: Numpy arrays are not JSON serializable so convert them to enable storing benchmark data. hidden_range = np.arange( D_MODEL_MIN, 4 * D_MODEL_MAX + 1, step_size diff --git a/python_benchmarks/normalization.py b/python_benchmarks/normalization.py index 9988ba34b2b..8c051554d5a 100644 --- a/python_benchmarks/normalization.py +++ b/python_benchmarks/normalization.py @@ -47,7 +47,7 @@ def norm_fwd_fusion( weight = fd.ops.cast(weight, dtype=DataType.Float) bias = fd.ops.cast(bias, dtype=DataType.Float) - var, mean = fd.ops.var_mean(input, axes=reduction_axes, correction=0, keepdim=False) + var, mean = fd.ops.var_mean(input, dims=reduction_axes, correction=0, keepdim=False) eps = fd.define_scalar(eps, dtype=DataType.Double) var_eps = fd.ops.add(var, eps) @@ -155,10 +155,10 @@ def norm_bwd_fusion( mean = fd.ops.broadcast(mean, bcast_mask) - grad_sum = fd.ops.sum(grad, axes=reduction_axes, keepdim=False) + grad_sum = fd.ops.sum(grad, dims=reduction_axes, keepdim=False) x_sub_mean = fd.ops.sub(input, mean) - dot_p = fd.ops.sum(fd.ops.mul(grad, x_sub_mean), axes=reduction_axes, keepdim=False) + dot_p = fd.ops.sum(fd.ops.mul(grad, x_sub_mean), dims=reduction_axes, keepdim=False) grad_mean = fd.ops.broadcast(fd.ops.mul(grad_sum, norm), bcast_mask) proj_scale = fd.ops.mul(fd.ops.mul(dot_p, norm), fd.ops.mul(invstd, invstd)) diff --git a/python_benchmarks/test_dropout_layernorm_bwd.py b/python_benchmarks/test_dropout_layernorm_bwd.py index 5e8737d21fe..db4a8cb0c12 100644 --- a/python_benchmarks/test_dropout_layernorm_bwd.py +++ b/python_benchmarks/test_dropout_layernorm_bwd.py @@ -53,26 +53,26 @@ def dropout_layernorm_bwd_fusion( T25 = fd.ops.broadcast_in_dim(T3, shape=V19, broadcast_dims=[0, 1]) T26 = fd.ops.mul(T21, T25) T30 = fd.ops.broadcast_in_dim(T5, shape=V19, broadcast_dims=[1]) - T35 = fd.ops.sum(T4, axes=[0], keepdim=False, dtype=DataType.Null) + T35 = fd.ops.sum(T4, dims=[0], keepdim=False, dtype=DataType.Null) T37 = fd.ops.mul(T4, T30) T38 = fd.ops.mul(T4, T26) - T39 = fd.ops.sum(T38, axes=[0], keepdim=False, dtype=DataType.Null) + T39 = fd.ops.sum(T38, dims=[0], keepdim=False, dtype=DataType.Null) T41 = fd.ops.mul(T37, T25) T42 = fd.ops.mul(T37, T21) - T43 = fd.ops.sum(T42, axes=[1], keepdim=False, dtype=DataType.Null) + T43 = fd.ops.sum(T42, dims=[1], keepdim=False, dtype=DataType.Null) T47 = fd.ops.broadcast_in_dim(T43, shape=V15, broadcast_dims=[0]) T48 = fd.ops.neg(T41) - T49 = fd.ops.sum(T48, axes=[1], keepdim=False, dtype=DataType.Null) + T49 = fd.ops.sum(T48, dims=[1], keepdim=False, dtype=DataType.Null) T53 = fd.ops.broadcast_in_dim(T49, shape=V15, broadcast_dims=[0]) S54 = fd.define_scalar(-0.500000, dtype=DataType.Double) T55 = fd.ops.mul(S54, T47) S56 = fd.define_scalar(3.00000, dtype=DataType.Double) T57 = fd.ops.pow(T3, S56) T58 = fd.ops.mul(T55, T57) - T61 = fd.ops.sum(T53, axes=[1], keepdim=False, dtype=DataType.Null) - T62 = fd.ops.sum(T58, axes=[1], keepdim=False, dtype=DataType.Null) 
+ T61 = fd.ops.sum(T53, dims=[1], keepdim=False, dtype=DataType.Null) + T62 = fd.ops.sum(T58, dims=[1], keepdim=False, dtype=DataType.Null) T66 = fd.ops.broadcast_in_dim(T62, shape=V15, broadcast_dims=[0]) T70 = fd.ops.broadcast_in_dim(T66, shape=V19, broadcast_dims=[0, 1]) T74 = fd.ops.broadcast_in_dim(T2, shape=V15, broadcast_dims=[0]) diff --git a/python_benchmarks/test_dropout_layernorm_fwd.py b/python_benchmarks/test_dropout_layernorm_fwd.py index 4136b74f86e..c80aaea7a25 100644 --- a/python_benchmarks/test_dropout_layernorm_fwd.py +++ b/python_benchmarks/test_dropout_layernorm_fwd.py @@ -38,7 +38,7 @@ def dropout_layernorm_fwd_fusion( T15 = fd.ops.mul(T13, S14) T16 = fd.ops.add(T2, T15) # Layernorm - T17, T18 = fd.ops.var_mean(T16, axes=[1], correction=0, keepdim=False) + T17, T18 = fd.ops.var_mean(T16, dims=[1], correction=0, keepdim=False) V21 = fd.define_vector([T2.size(0), 1], dtype=DataType.Int) T22 = fd.ops.broadcast_in_dim(T17, shape=V21, broadcast_dims=[0]) T26 = fd.ops.broadcast_in_dim(T18, shape=V21, broadcast_dims=[0]) diff --git a/python_benchmarks/test_dropout_rmsnorm_bwd.py b/python_benchmarks/test_dropout_rmsnorm_bwd.py index 73d966b46a4..0f5046b5de0 100644 --- a/python_benchmarks/test_dropout_rmsnorm_bwd.py +++ b/python_benchmarks/test_dropout_rmsnorm_bwd.py @@ -54,7 +54,7 @@ def dropout_rmsnorm_bwd_fusion( T30 = fd.ops.mul(T8, T23) T31 = fd.ops.mul(T8, T27) - T32 = fd.ops.sum(T30, axes=[0], keepdim=False, dtype=DataType.Null) + T32 = fd.ops.sum(T30, dims=[0], keepdim=False, dtype=DataType.Null) T35 = fd.ops.mul(T31, T22) T36 = fd.ops.neg(T31) @@ -63,7 +63,7 @@ def dropout_rmsnorm_bwd_fusion( T39 = fd.ops.pow(T20, S38) T40 = fd.ops.reciprocal(T39) T41 = fd.ops.mul(T37, T40) - T42 = fd.ops.sum(T41, axes=[1], keepdim=False, dtype=DataType.Null) + T42 = fd.ops.sum(T41, dims=[1], keepdim=False, dtype=DataType.Null) V60 = fd.define_vector([T5.size(0), 1], dtype=DataType.Int) T47 = fd.ops.broadcast_in_dim(T42, shape=V60, broadcast_dims=[0]) @@ -73,7 +73,7 @@ def dropout_rmsnorm_bwd_fusion( T52 = fd.ops.mul(T47, T51) S55 = fd.ops.reciprocal(T5.size(1)) T56 = fd.ops.mul(T52, S55) - T57 = fd.ops.sum(T56, axes=[1], keepdim=False, dtype=DataType.Null) + T57 = fd.ops.sum(T56, dims=[1], keepdim=False, dtype=DataType.Null) T61 = fd.ops.broadcast_in_dim(T57, shape=V60, broadcast_dims=[0]) T65 = fd.ops.broadcast_in_dim(T61, shape=V19, broadcast_dims=[0, 1]) diff --git a/python_benchmarks/test_dropout_rmsnorm_fwd.py b/python_benchmarks/test_dropout_rmsnorm_fwd.py index 337e8ff76ff..dd4454d9d0d 100644 --- a/python_benchmarks/test_dropout_rmsnorm_fwd.py +++ b/python_benchmarks/test_dropout_rmsnorm_fwd.py @@ -43,7 +43,7 @@ def dropout_rmsnorm_fwd_fusion( T15 = fd.ops.add(T0, T14) S16 = fd.define_scalar(2.00000, dtype=DataType.Double) T17 = fd.ops.pow(T15, S16) - T18 = fd.ops.sum(T17, axes=[1], keepdim=False, dtype=DataType.Null) + T18 = fd.ops.sum(T17, dims=[1], keepdim=False, dtype=DataType.Null) V21 = fd.define_vector([T0.size(0), 1], dtype=DataType.Int) T22 = fd.ops.broadcast_in_dim(T18, shape=V21, broadcast_dims=[0]) diff --git a/python_benchmarks/test_gelu_bwd_reduction.py b/python_benchmarks/test_gelu_bwd_reduction.py index 71e9a363dde..1590b73493d 100644 --- a/python_benchmarks/test_gelu_bwd_reduction.py +++ b/python_benchmarks/test_gelu_bwd_reduction.py @@ -47,7 +47,7 @@ def gelu_bwd_reduction_fusion( T18 = fd.ops.mul(T17, S2) T19 = fd.ops.add(T16, T18) T20 = fd.ops.mul(grad, T19) - T21 = fd.ops.sum(T20, axes=[reduction_axis], keepdim=False) + T21 = fd.ops.sum(T20, 
dims=[reduction_axis], keepdim=False) if dtype in PROMOTE_DTYPES: T21 = fd.ops.cast(T21, dtype=dtype) fd.add_output(T21) diff --git a/python_benchmarks/test_huggingface_attn_bwd.py b/python_benchmarks/test_huggingface_attn_bwd.py index 108b9fd161e..c8bdb6d666a 100644 --- a/python_benchmarks/test_huggingface_attn_bwd.py +++ b/python_benchmarks/test_huggingface_attn_bwd.py @@ -37,7 +37,7 @@ def huggingface_attn_bwd_fusion( T10 = fd.ops.mul(T5, S0) T11 = fd.ops.mul(T10, T7) T13 = fd.ops.mul(T6, T11) - T14 = fd.ops.sum(T13, axes=[2], keepdim=False, dtype=DataType.Null) + T14 = fd.ops.sum(T13, dims=[2], keepdim=False, dtype=DataType.Null) V18 = fd.define_vector([T5.size(0), T5.size(1), 1], dtype=DataType.Int) T19 = fd.ops.broadcast_in_dim(T14, shape=V18, broadcast_dims=[0, 1]) diff --git a/python_benchmarks/test_huggingface_attn_fwd.py b/python_benchmarks/test_huggingface_attn_fwd.py index 6410ee3e21f..2a0b6090cd2 100644 --- a/python_benchmarks/test_huggingface_attn_fwd.py +++ b/python_benchmarks/test_huggingface_attn_fwd.py @@ -36,14 +36,14 @@ def huggingface_attn_fwd_fusion( [T0.size(0) * T0.size(1), T0.size(2), T0.size(3)], dtype=DataType.Int ) T10 = fd.ops.reshape(T4, new_shape=V9) - T12 = fd.ops.max(T10, axes=[2], keepdim=False, dtype=DataType.Null) + T12 = fd.ops.max(T10, dims=[2], keepdim=False, dtype=DataType.Null) V16 = fd.define_vector([T0.size(0) * T0.size(1), T0.size(2), 1], dtype=DataType.Int) T17 = fd.ops.broadcast_in_dim(T12, shape=V16, broadcast_dims=[0, 1]) T22 = fd.ops.broadcast_in_dim(T17, shape=V9, broadcast_dims=[0, 1, 2]) T23 = fd.ops.sub(T10, T22) T24 = fd.ops.exp(T23) - T25 = fd.ops.sum(T24, axes=[2], keepdim=False, dtype=DataType.Null) + T25 = fd.ops.sum(T24, dims=[2], keepdim=False, dtype=DataType.Null) T30 = fd.ops.broadcast_in_dim(T25, shape=V16, broadcast_dims=[0, 1]) T35 = fd.ops.broadcast_in_dim(T30, shape=V9, broadcast_dims=[0, 1, 2]) diff --git a/python_benchmarks/test_layernorm_bwd.py b/python_benchmarks/test_layernorm_bwd.py index 20bb91cff11..91951a0e668 100644 --- a/python_benchmarks/test_layernorm_bwd.py +++ b/python_benchmarks/test_layernorm_bwd.py @@ -41,27 +41,27 @@ def layernorm_bwd_fusion( T19 = fd.ops.mul(T14, T18) T23 = fd.ops.broadcast_in_dim(T4, shape=V12, broadcast_dims=[1]) - T28 = fd.ops.sum(T1, axes=[0], keepdim=False, dtype=DataType.Null) + T28 = fd.ops.sum(T1, dims=[0], keepdim=False, dtype=DataType.Null) T30 = fd.ops.mul(T1, T23) T31 = fd.ops.mul(T1, T19) - T32 = fd.ops.sum(T31, axes=[0], keepdim=False, dtype=DataType.Null) + T32 = fd.ops.sum(T31, dims=[0], keepdim=False, dtype=DataType.Null) T34 = fd.ops.mul(T30, T18) T35 = fd.ops.mul(T30, T14) - T36 = fd.ops.sum(T35, axes=[1], keepdim=False, dtype=DataType.Null) + T36 = fd.ops.sum(T35, dims=[1], keepdim=False, dtype=DataType.Null) T40 = fd.ops.broadcast_in_dim(T36, shape=V8, broadcast_dims=[0]) T41 = fd.ops.neg(T34) - T42 = fd.ops.sum(T41, axes=[1], keepdim=False, dtype=DataType.Null) + T42 = fd.ops.sum(T41, dims=[1], keepdim=False, dtype=DataType.Null) T46 = fd.ops.broadcast_in_dim(T42, shape=V8, broadcast_dims=[0]) S47 = fd.define_scalar(-0.500000, dtype=DataType.Double) T48 = fd.ops.mul(S47, T40) S49 = fd.define_scalar(3.00000, dtype=DataType.Double) T50 = fd.ops.pow(T3, S49) T51 = fd.ops.mul(T48, T50) - T54 = fd.ops.sum(T46, axes=[1], keepdim=False, dtype=DataType.Null) - T55 = fd.ops.sum(T51, axes=[1], keepdim=False, dtype=DataType.Null) + T54 = fd.ops.sum(T46, dims=[1], keepdim=False, dtype=DataType.Null) + T55 = fd.ops.sum(T51, dims=[1], keepdim=False, dtype=DataType.Null) T59 = 
fd.ops.broadcast_in_dim(T55, shape=V8, broadcast_dims=[0]) T63 = fd.ops.broadcast_in_dim(T59, shape=V12, broadcast_dims=[0, 1]) diff --git a/python_benchmarks/test_layernorm_fwd.py b/python_benchmarks/test_layernorm_fwd.py index ca657b3521a..7393054f819 100644 --- a/python_benchmarks/test_layernorm_fwd.py +++ b/python_benchmarks/test_layernorm_fwd.py @@ -22,7 +22,7 @@ def layernorm_fwd_fusion( T1 = fd.ops.cast(T1, dtype=DataType.Float) T2 = fd.ops.cast(T2, dtype=DataType.Float) - T3, T4 = fd.ops.var_mean(T0, axes=[1], correction=0, keepdim=False) + T3, T4 = fd.ops.var_mean(T0, dims=[1], correction=0, keepdim=False) V6 = fd.define_vector([T0.size(0), 1], dtype=DataType.Int) T7 = fd.ops.broadcast_in_dim(T3, shape=V6, broadcast_dims=[0]) diff --git a/python_benchmarks/test_nanogpt_attn_bwd.py b/python_benchmarks/test_nanogpt_attn_bwd.py index 3c22878adaf..a79dd752e49 100644 --- a/python_benchmarks/test_nanogpt_attn_bwd.py +++ b/python_benchmarks/test_nanogpt_attn_bwd.py @@ -47,7 +47,7 @@ def nanogpt_attn_bwd_fusion( T7 = fd.ops.mul(T2, S0) T8 = fd.ops.mul(T7, T4) T9 = fd.ops.mul(T3, T8) - T10 = fd.ops.sum(T9, axes=[3], keepdim=False, dtype=DataType.Null) + T10 = fd.ops.sum(T9, dims=[3], keepdim=False, dtype=DataType.Null) V15 = fd.define_vector([T2.size(0), T2.size(1), T2.size(2), 1], dtype=DataType.Int) T16 = fd.ops.broadcast_in_dim(T10, shape=V15, broadcast_dims=[0, 1, 2]) diff --git a/python_benchmarks/test_nanogpt_attn_fwd.py b/python_benchmarks/test_nanogpt_attn_fwd.py index c0cb1964eaa..4687f20c1aa 100644 --- a/python_benchmarks/test_nanogpt_attn_fwd.py +++ b/python_benchmarks/test_nanogpt_attn_fwd.py @@ -35,13 +35,13 @@ def nanogpt_attn_fwd_fusion( T11 = fd.ops.broadcast_in_dim(T5, shape=V10, broadcast_dims=[0, 1, 2, 3]) S12 = fd.define_scalar(float("-inf"), dtype=DataType.Double) T13 = fd.ops.where(T11, S12, T3) - T14 = fd.ops.max(T13, axes=[3], keepdim=False, dtype=DataType.Null) + T14 = fd.ops.max(T13, dims=[3], keepdim=False, dtype=DataType.Null) V19 = fd.define_vector([T0.size(0), T0.size(1), T0.size(2), 1], dtype=DataType.Int) T20 = fd.ops.broadcast_in_dim(T14, shape=V19, broadcast_dims=[0, 1, 2]) T26 = fd.ops.broadcast_in_dim(T20, shape=V10, broadcast_dims=[0, 1, 2, 3]) T27 = fd.ops.sub(T13, T26) T28 = fd.ops.exp(T27) - T29 = fd.ops.sum(T28, axes=[3], keepdim=False, dtype=DataType.Null) + T29 = fd.ops.sum(T28, dims=[3], keepdim=False, dtype=DataType.Null) T35 = fd.ops.broadcast_in_dim(T29, shape=V19, broadcast_dims=[0, 1, 2]) T41 = fd.ops.broadcast_in_dim(T35, shape=V10, broadcast_dims=[0, 1, 2, 3]) T42 = fd.ops.reciprocal(T41) diff --git a/python_benchmarks/test_reduction.py b/python_benchmarks/test_reduction.py index 8db230d81bd..b54426c3420 100644 --- a/python_benchmarks/test_reduction.py +++ b/python_benchmarks/test_reduction.py @@ -16,7 +16,7 @@ def reduction_fusion( ) if dtype in PROMOTE_DTYPES: T0 = fd.ops.cast(T0, dtype=DataType.Float) - T2 = fd.ops.sum(T0, axes=[reduction_axis], keepdim=False) + T2 = fd.ops.sum(T0, dims=[reduction_axis], keepdim=False) if dtype in PROMOTE_DTYPES: T2 = fd.ops.cast(T2, dtype=dtype) fd.add_output(T2) diff --git a/python_benchmarks/test_rmsnorm_bwd.py b/python_benchmarks/test_rmsnorm_bwd.py index 5b5f10b7775..71026c3674d 100644 --- a/python_benchmarks/test_rmsnorm_bwd.py +++ b/python_benchmarks/test_rmsnorm_bwd.py @@ -35,7 +35,7 @@ def rmsnorm_bwd_fusion( T23 = fd.ops.mul(T6, T16) T24 = fd.ops.mul(T6, T20) - T25 = fd.ops.sum(T23, axes=[0], keepdim=False, dtype=DataType.Null) + T25 = fd.ops.sum(T23, dims=[0], keepdim=False, 
dtype=DataType.Null) T28 = fd.ops.mul(T24, T15) T29 = fd.ops.neg(T24) @@ -43,7 +43,7 @@ def rmsnorm_bwd_fusion( T32 = fd.ops.pow(T14, S0) T33 = fd.ops.reciprocal(T32) T34 = fd.ops.mul(T30, T33) - T35 = fd.ops.sum(T34, axes=[1], keepdim=False, dtype=DataType.Null) + T35 = fd.ops.sum(T34, dims=[1], keepdim=False, dtype=DataType.Null) V39 = fd.define_vector([T4.size(0), 1], dtype=DataType.Int) T41 = fd.ops.broadcast_in_dim(T35, shape=V39, broadcast_dims=[0]) T43 = fd.ops.mul(S0, T5) @@ -51,7 +51,7 @@ def rmsnorm_bwd_fusion( T45 = fd.ops.mul(T41, T44) S48 = fd.ops.reciprocal(T4.size(1)) T49 = fd.ops.mul(T45, S48) - T50 = fd.ops.sum(T49, axes=[1], keepdim=False, dtype=DataType.Null) + T50 = fd.ops.sum(T49, dims=[1], keepdim=False, dtype=DataType.Null) T54 = fd.ops.broadcast_in_dim(T50, shape=V39, broadcast_dims=[0]) T58 = fd.ops.broadcast_in_dim(T54, shape=T4.shape(), broadcast_dims=[0, 1]) T59 = fd.ops.mul(T58, S0) diff --git a/python_benchmarks/test_rmsnorm_fwd.py b/python_benchmarks/test_rmsnorm_fwd.py index 6b4116298ba..651a68f5089 100644 --- a/python_benchmarks/test_rmsnorm_fwd.py +++ b/python_benchmarks/test_rmsnorm_fwd.py @@ -20,7 +20,7 @@ def rmsnorm_fwd_fusion( T1 = fd.ops.cast(T1, dtype=DataType.Float) S3 = fd.define_scalar(2.00000, dtype=DataType.Double) T4 = fd.ops.pow(T0, S3) - T5 = fd.ops.sum(T4, axes=[1], keepdim=False, dtype=DataType.Null) + T5 = fd.ops.sum(T4, dims=[1], keepdim=False, dtype=DataType.Null) V8 = fd.define_vector([T0.size(0), 1], dtype=DataType.Int) T9 = fd.ops.broadcast_in_dim(T5, shape=V8, broadcast_dims=[0]) S11 = fd.ops.reciprocal(T0.size(1)) diff --git a/python_benchmarks/test_rope.py b/python_benchmarks/test_rope.py new file mode 100644 index 00000000000..50797fb1f8d --- /dev/null +++ b/python_benchmarks/test_rope.py @@ -0,0 +1,184 @@ +import pytest +from nvfuser import FusionDefinition, DataType +from .core import run_benchmark, clear_cuda_cache +import torch + + +# Mimic the Hugging Face implementation: +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L216 +def rope_with_cat_fusion( + fd: FusionDefinition, + batch_size: int, + seq_len: int, + num_heads: int, + features_per_head: int, +) -> None: + q = fd.define_tensor( + shape=[batch_size, seq_len, num_heads, features_per_head], + dtype=DataType.BFloat16, + ) + cos = fd.define_tensor( + shape=[seq_len, features_per_head], + dtype=DataType.BFloat16, + ) + sin = fd.define_tensor( + shape=[seq_len, features_per_head], + dtype=DataType.BFloat16, + ) + + q = fd.ops.permute(q, dims=[0, 2, 1, 3]) + q_real = fd.ops.slice( + q, + start_indices=[0, 0, 0, 0], + end_indices=[batch_size, num_heads, seq_len, features_per_head // 2], + strides=[1, 1, 1, 1], + ) + q_image = fd.ops.slice( + q, + start_indices=[0, 0, 0, features_per_head // 2], + end_indices=[batch_size, num_heads, seq_len, features_per_head], + strides=[1, 1, 1, 1], + ) + + # nvFuser has problems generating negation for bfloat. 
+ q_image = fd.ops.cast(q_image, dtype=DataType.Float) + q_image = -q_image + q_image = fd.ops.cast(q_image, dtype=DataType.BFloat16) + + q_rotated = fd.ops.cat([q_image, q_real], dim=-1) + + cos = fd.ops.broadcast_in_dim( + cos, shape=[1, 1, seq_len, features_per_head], broadcast_dims=[2, 3] + ) + sin = fd.ops.broadcast_in_dim( + sin, shape=[1, 1, seq_len, features_per_head], broadcast_dims=[2, 3] + ) + + out = q * cos + q_rotated * sin + out = fd.ops.cast(out, DataType.BFloat16) + fd.add_output(out) + + +# Idea from @nikitaved: we split and concatenate the embeddings instead of `q`. +# The embeddings are constant that can be precomputed. So we pay the overhead +# of split and concatenation only once. The actual forward pass is merely +# elementwise+reduction surrounded by some meta ops. +def rope_without_cat_fusion( + fd: FusionDefinition, + batch_size: int, # B + seq_len: int, # S + num_heads: int, # H + features_per_head: int, # F +) -> None: + q = fd.define_tensor( + shape=[batch_size, seq_len, num_heads, features_per_head], + dtype=DataType.BFloat16, + ) + # `cos_sin_matrix` is essentially a batch (of size S*F/2) of 2x2 matrices + # laid out in a special way to keep computation simple. + # + # Using the notations in Figure 1 in https://arxiv.org/pdf/2104.09864.pdf, + # cos_sin_matrix[0] contains the following: + # + # cos(θ_1), -sin(θ1) + # cos(θ_2), -sin(θ2) + # ... + # cos(θ_F/2), -sin(θ_F/2) + # ------------------------ + # sin(θ_1), cos(θ_1) + # sin(θ_2), cos(θ_2) + # ... + # sin(θ_F/2), cos(θ_F/2) + # + # cos_sin_matrix[i] is similar but each θ is multiplied by `i+1`. + cos_sin_matrix = fd.define_tensor( + shape=[seq_len, 2, features_per_head // 2, 2], + dtype=DataType.BFloat16, + ) + + q = fd.ops.reshape( + q, new_shape=[batch_size, seq_len, num_heads, 2, features_per_head // 2] + ) + q = fd.ops.permute(q, dims=[0, 2, 1, 4, 3]) + q = fd.ops.broadcast_in_dim( + q, + shape=[batch_size, num_heads, seq_len, 1, features_per_head // 2, 2], + broadcast_dims=[0, 1, 2, 4, 5], + ) + + cos_sin_matrix = fd.ops.broadcast_in_dim( + cos_sin_matrix, + shape=[batch_size, num_heads, seq_len, 2, features_per_head // 2, 2], + broadcast_dims=[2, 3, 4, 5], + ) + + out = fd.ops.sum(q * cos_sin_matrix, [-1]) + out = fd.ops.cast(out, DataType.BFloat16) + out = fd.ops.reshape( + out, new_shape=[batch_size, num_heads, seq_len, features_per_head] + ) + fd.add_output(out) + + +@pytest.mark.parametrize("use_cat", [True, False], ids=["with_cat", "without_cat"]) +def test_rope_benchmark( + benchmark, use_cat: bool, disable_validation: bool, disable_benchmarking: bool +): + clear_cuda_cache() + + batch_size = 32 + seq_len = 4096 + num_heads = 32 + features_per_head = 128 + + # torch.manual_seed(0) + q = torch.randn( + batch_size, + seq_len, + num_heads, + features_per_head, + dtype=torch.bfloat16, + device="cuda:0", + ) + freqs = torch.randn( + seq_len, features_per_head // 2, dtype=torch.bfloat16, device="cuda:0" + ) + cos = freqs.cos() + sin = freqs.sin() + + if use_cat: + with FusionDefinition() as fd: + rope_with_cat_fusion(fd, batch_size, seq_len, num_heads, features_per_head) + inputs = [q, torch.cat([cos, cos], dim=-1), torch.cat([sin, sin], dim=-1)] + else: + with FusionDefinition() as fd: + rope_without_cat_fusion( + fd, batch_size, seq_len, num_heads, features_per_head + ) + # [S, F/2, 2] + cos_and_minus_sin = torch.stack([cos, -sin], dim=-1) + # [S, F/2, 2] + sin_and_cos = torch.stack([sin, cos], dim=-1) + # [S, 2, F/2, 2] + cos_sin_matrix = torch.stack([cos_and_minus_sin, sin_and_cos], dim=1) + 
inputs = [q, cos_sin_matrix] + + if not disable_validation: + q_real, q_image = q.permute([0, 2, 1, 3]).split(features_per_head // 2, dim=-1) + q_real = q_real.to(torch.float32) + q_image = q_image.to(torch.float32) + ref_out = torch.cat( + [q_real * cos - q_image * sin, q_image * cos + q_real * sin], dim=-1 + ).to(torch.bfloat16) + nvf_out = fd.execute(inputs) + if use_cat: + # For unknown reasons, rope_with_cat_fusion produces slightly off + # numbers. It looks like a problem with cast-to-bfloat, because when + # I removed the fd.ops.cast in the end of that fusion, the results + # were bit-exact as the reference output. + torch.testing.assert_close(nvf_out, [ref_out], atol=0.01, rtol=0.01) + else: + torch.testing.assert_close(nvf_out, [ref_out], atol=0, rtol=0) + + if not disable_benchmarking: + run_benchmark(benchmark, fd.execute, inputs) diff --git a/python_benchmarks/test_softmax_bwd.py b/python_benchmarks/test_softmax_bwd.py index 8d2d5b6c6bb..52e7d4c3682 100644 --- a/python_benchmarks/test_softmax_bwd.py +++ b/python_benchmarks/test_softmax_bwd.py @@ -27,7 +27,7 @@ def softmax_bwd_fusion( T1 = fd.ops.cast(T1, dtype=DataType.Float) T4 = fd.ops.mul(T0, T1) - T5 = fd.ops.sum(T4, axes=[reduction_axis], keepdim=False, dtype=DataType.Null) + T5 = fd.ops.sum(T4, dims=[reduction_axis], keepdim=False, dtype=DataType.Null) if reduction_axis: V9 = fd.define_vector([T0.size(0), 1], dtype=DataType.Int) @@ -54,10 +54,15 @@ def softmax_bwd_fusion( fd.add_output(T19) +def unary_bwd_torch(inputs: list): # [in_tensor, output, grads] + inputs[1].backward(inputs[2], retain_graph=True) + return inputs[0].grad + + @pytest.mark.parametrize("size", generate_input_sizes(dims=2)) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) @pytest.mark.parametrize("reduction_axis", [0, 1]) -def test_softmax_bwd_benchmark( +def test_softmax_bwd_nvf_benchmark( benchmark, size: tuple, dtype: torch.dtype, @@ -68,8 +73,8 @@ def test_softmax_bwd_benchmark( clear_cuda_cache() inputs = [ - torch.randn(*size, device="cuda", dtype=dtype, requires_grad=True), - torch.randn(*size, device="cuda", dtype=dtype), + torch.randn(size, device="cuda", dtype=dtype, requires_grad=True), + torch.randn(size, device="cuda", dtype=dtype), ] with FusionDefinition() as fd: @@ -82,3 +87,25 @@ def test_softmax_bwd_benchmark( if not disable_benchmarking: run_benchmark(benchmark, fd.execute, inputs) + + +@pytest.mark.parametrize("compile", [False, True], ids=["eager", "compile"]) +@pytest.mark.parametrize("size", generate_input_sizes(dims=2)) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +@pytest.mark.parametrize("reduction_axis", [0, 1]) +def test_softmax_bwd_baseline_benchmark( + benchmark, + size: tuple, + dtype: torch.dtype, + reduction_axis: int, + compile: bool, +): + clear_cuda_cache() + input = torch.randn(size, device="cuda", dtype=dtype, requires_grad=True) + grads = torch.randn(size, device="cuda", dtype=dtype) + output = torch.nn.functional.softmax(input, dim=reduction_axis) + run_benchmark( + benchmark, + torch.compile(unary_bwd_torch) if compile else unary_bwd_torch, + [input, output, grads], + ) diff --git a/python_benchmarks/test_softmax_fwd.py b/python_benchmarks/test_softmax_fwd.py index cb48f2bee30..7d7e3239601 100644 --- a/python_benchmarks/test_softmax_fwd.py +++ b/python_benchmarks/test_softmax_fwd.py @@ -17,7 +17,7 @@ def softmax_fwd_fusion( ) if dtype in PROMOTE_DTYPES: T0 = fd.ops.cast(T0, dtype=DataType.Float) - T2 = fd.ops.max(T0, axes=[reduction_axis], keepdim=False, dtype=DataType.Null) + T2 = fd.ops.max(T0, 
dims=[reduction_axis], keepdim=False, dtype=DataType.Null) if reduction_axis: V6 = fd.define_vector([T0.size(0), 1], dtype=DataType.Int) @@ -31,7 +31,7 @@ def softmax_fwd_fusion( T12 = fd.ops.broadcast_in_dim(T7, shape=V11, broadcast_dims=[0, 1]) T13 = fd.ops.sub(T0, T12) T14 = fd.ops.exp(T13) - T15 = fd.ops.sum(T14, axes=[reduction_axis], keepdim=False, dtype=DataType.Null) + T15 = fd.ops.sum(T14, dims=[reduction_axis], keepdim=False, dtype=DataType.Null) T20 = fd.ops.broadcast_in_dim(T15, shape=V6, broadcast_dims=[bcast_dim]) T25 = fd.ops.broadcast_in_dim(T20, shape=V11, broadcast_dims=[0, 1]) @@ -44,10 +44,14 @@ def softmax_fwd_fusion( fd.add_output(T27) +def softmax_fwd_fn(inputs: list): # [in_tensor, reduction_axis] + return torch.nn.functional.softmax(inputs[0], inputs[1]) + + @pytest.mark.parametrize("size", generate_input_sizes(dims=2)) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) @pytest.mark.parametrize("reduction_axis", [0, 1]) -def test_softmax_fwd_benchmark( +def test_softmax_fwd_nvf_benchmark( benchmark, size: tuple, dtype: torch.dtype, @@ -57,14 +61,34 @@ def test_softmax_fwd_benchmark( ): clear_cuda_cache() - inputs = [torch.randn(*size, device="cuda", dtype=dtype)] + inputs = [torch.randn(size, device="cuda", dtype=dtype)] with FusionDefinition() as fd: softmax_fwd_fusion(fd, torch_dtype_to_nvfuser_dtype(dtype), reduction_axis) if not disable_validation: - eager_output = torch.nn.functional.softmax(inputs[0], dim=reduction_axis) + eager_output = softmax_fwd_fn([inputs[0], reduction_axis]) fd.validate(inputs, [eager_output]) if not disable_benchmarking: run_benchmark(benchmark, fd.execute, inputs) + + +@pytest.mark.parametrize("compile", [False, True], ids=["eager", "compile"]) +@pytest.mark.parametrize("size", generate_input_sizes(dims=2)) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +@pytest.mark.parametrize("reduction_axis", [0, 1]) +def test_softmax_fwd_baseline_benchmark( + benchmark, + size: tuple, + dtype: torch.dtype, + reduction_axis: int, + compile: bool, +): + clear_cuda_cache() + input = torch.randn(size, device="cuda", dtype=dtype) + run_benchmark( + benchmark, + torch.compile(softmax_fwd_fn) if compile else softmax_fwd_fn, + [input, reduction_axis], + ) diff --git a/python_tests/pytest_input_generators.py b/python_tests/pytest_input_generators.py index ed544f2f63f..ac282584eea 100644 --- a/python_tests/pytest_input_generators.py +++ b/python_tests/pytest_input_generators.py @@ -26,7 +26,6 @@ MINIMUM_SYMBOLIC_SIZE = -1 INT64_MAX = 2**63 - 1 MAX_TENSOR_DIMS = 8 -MAX_VECTOR_SIZE = 8 # Determine if a number is with desired Domain [low, high) @@ -468,42 +467,11 @@ def define_vector_constant_error_generator( "The value -2 at index 0 was neither symbolic(-1), zero_element(0), broadcast(1), or static(>1)", ) - check_max_vector_size = ErrorSample( - { - "values": [-1 for _ in range(MAX_VECTOR_SIZE + 1)], - }, - "The specified vector size exceeds the max tensor size for nvfuser.", - ) - error_cases = [ # FIXME: The above_size_range case gives a non-sensical error message. 
# "Unable to cast Python instance to C++ type (#define PYBIND11_DETAILED_ER" # check_above_size_range, check_below_size_range, - check_max_vector_size, - ] - - for es in error_cases: - yield SampleInput(**es.kwargs), es.ex_type, es.ex_str - - -def define_vector_input_error_generator( - op: OpInfo, dtype: torch.dtype, requires_grad: bool = False, **kwargs -): - """ - "define_vector", - [](FusionDefinition& self, size_t size) -> Vector { - """ - - check_max_vector_size = ErrorSample( - { - "size": (MAX_VECTOR_SIZE + 1), - }, - "The specified vector size exceeds the max tensor size for nvfuser.", - ) - - error_cases = [ - check_max_vector_size, ] for es in error_cases: @@ -1255,8 +1223,12 @@ def squeeze_generator( ((1, 1, 1), (0, 1, 2)), ((1, 1, 1), (-3, -2, -1)), # No-op test cases - ((5, 5, 5), (0, 1, 2)), - ((5, 5, 5), (-3, -2, -1)), + # NOTE: These are skipped. We diverge from PyTorch behavior for squeeze + # in nvFuser. Our squeeze op will throw an exception if we pass a + # squeeze dimension that cannot be squeezed. + # See https://github.com/NVIDIA/Fuser/pull/1717 + # ((5, 5, 5), (0, 1, 2)), + # ((5, 5, 5), (-3, -2, -1)), ((), ()), ) diff --git a/python_tests/pytest_opinfos.py b/python_tests/pytest_opinfos.py index 53d6f0f80b3..e52eb7fc183 100644 --- a/python_tests/pytest_opinfos.py +++ b/python_tests/pytest_opinfos.py @@ -22,7 +22,6 @@ define_tensor_generator, define_tensor_error_generator, define_vector_constant_error_generator, - define_vector_input_error_generator, elementwise_binary_generator, _elementwise_binary_torch, elementwise_unary_generator, @@ -90,15 +89,6 @@ ) fusion_input_ops.append(define_vector_constant_opinfo) -define_vector_input_opinfo = OpInfo( - lambda fd: fd.define_vector, - "define_vector_input", - sample_input_generator=None, - error_input_generator=define_vector_input_error_generator, - fd_error_input_fn=api_test_fd_fn, -) -fusion_input_ops.append(define_vector_input_opinfo) - """ End Fusion Input Operations """ """ Start Unary-Float Operations """ diff --git a/python_tests/test_normalization.py b/python_tests/test_normalization.py index 00bda1b83fd..01a181687b8 100644 --- a/python_tests/test_normalization.py +++ b/python_tests/test_normalization.py @@ -147,6 +147,7 @@ def test_instance_norm( assert_close(m.bias.grad, reference_m.bias.grad) +@unittest.skip("disable failing test, see https://github.com/NVIDIA/Fuser/issues/1728") @unittest.skipIf(torch.cuda.device_count() < 2, "more than 1 GPU required") def test_instance_norm_multigpu(): class Model(nn.Module): diff --git a/python_tests/test_python_frontend.py b/python_tests/test_python_frontend.py index b6eab3a4232..f0669209e81 100644 --- a/python_tests/test_python_frontend.py +++ b/python_tests/test_python_frontend.py @@ -369,12 +369,12 @@ def nvfuser_fusion( shape=[-1], contiguity=[True], dtype=DataType.Float ) bias = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.Float) - sum0 = fd.ops.sum(inputs, axes=[normalization_axis], keepdim=keepDim) + sum0 = fd.ops.sum(inputs, dims=[normalization_axis], keepdim=keepDim) norm_const = fd.define_scalar(norm_size) mean = fd.ops.div(sum0, norm_const) diff = fd.ops.sub(inputs, mean) diff_sq = fd.ops.mul(diff, diff) - sum1 = fd.ops.sum(diff_sq, axes=[normalization_axis], keepdim=keepDim) + sum1 = fd.ops.sum(diff_sq, dims=[normalization_axis], keepdim=keepDim) var = fd.ops.div(sum1, norm_const) eps_const = fd.define_scalar(eps) var_eps = fd.ops.add(var, eps_const) @@ -410,7 +410,7 @@ def nvfuser_fusion_var_mean( ) bias = fd.define_tensor(shape=[-1], 
contiguity=[True], dtype=DataType.Float) var, mean = fd.ops.var_mean( - inputs, axes=[normalization_axis], correction=0, keepdim=keepDim + inputs, dims=[normalization_axis], correction=0, keepdim=keepDim ) eps_const = fd.define_scalar(eps) var_eps = fd.ops.add(var, eps_const) @@ -491,7 +491,7 @@ def nvfuser_fusion( shape=[-1], contiguity=[True], dtype=DataType.Float ) inputs_sq = fd.ops.mul(inputs, inputs) - sum0 = fd.ops.sum(inputs_sq, axes=[normalization_axis], keepdim=keepDim) + sum0 = fd.ops.sum(inputs_sq, dims=[normalization_axis], keepdim=keepDim) norm_const = fd.define_scalar(norm_size) var = fd.ops.div(sum0, norm_const) eps_const = fd.define_scalar(eps) @@ -519,6 +519,26 @@ def nvfuser_fusion( self.assertEqual(eager_out, nvf_out[0]) + def test_tensor_ndim(self): + shape = [2 for i in range(12)] + new_shape = shape[:9] + new_shape.append(8) + + inputs = [torch.randn(shape, device="cuda"), new_shape] + + def fusion_func(fd: FusionDefinition): + t0 = fd.from_pytorch(inputs[0]) + n_shape = fd.define_vector(10) + + t1 = fd.ops.reshape(t0, n_shape) + t2 = fd.ops.sum(t1, dims=[3]) + + fd.add_output(t2) + + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + eager_out = torch.sum(inputs[0].reshape(new_shape), dim=3) + self.assertEqual(eager_out, nvf_out[0]) + # Testing a scenario where a broadcast requires a symbolic output shape def test_tensor_shape(self): inputs = [ @@ -674,7 +694,7 @@ def test_tensor_shape_with_output_bcast(self): def fusion_func(fd: FusionDefinition): t0 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True]) - t1 = fd.ops.sum(t0, axes=[2]) + t1 = fd.ops.sum(t0, dims=[2]) t1_b = fd.ops.broadcast_in_dim(t1, t0.shape(), [0, 1]) fd.add_output(t1_b) @@ -1672,8 +1692,8 @@ def schedule(self): ctx_seg_fusion = FusionDefinition() with ctx_seg_fusion: t0 = ctx_seg_fusion.from_pytorch(inputs[0]) - t1 = ctx_seg_fusion.ops.sum(t0, axis=0) - t2 = ctx_seg_fusion.ops.sum(t0, axis=-1) + t1 = ctx_seg_fusion.ops.sum(t0, dim=0) + t2 = ctx_seg_fusion.ops.sum(t0, dim=-1) ctx_seg_fusion.add_output(t1) ctx_seg_fusion.add_output(t2) @@ -1922,7 +1942,7 @@ def nvfuser_fusion(fd: FusionDefinition, prob) -> None: S10 = fd.define_scalar(-1, dtype=DataType.Int) S11 = fd.define_scalar(4, dtype=DataType.Int) S12 = fd.ops.add(S10, S11) - T13 = fd.ops.max(T9, axes=[3], keepdim=False, dtype=DataType.Null) + T13 = fd.ops.max(T9, dims=[3], keepdim=False, dtype=DataType.Null) T14 = fd.ops.broadcast_in_dim( T13, shape=[16, 16, 128, 1], broadcast_dims=[0, 1, 2] ) @@ -1934,7 +1954,7 @@ def nvfuser_fusion(fd: FusionDefinition, prob) -> None: S18 = fd.define_scalar(-1, dtype=DataType.Int) S19 = fd.define_scalar(4, dtype=DataType.Int) S20 = fd.ops.add(S18, S19) - T21 = fd.ops.sum(T17, axes=[3], keepdim=False, dtype=DataType.Null) + T21 = fd.ops.sum(T17, dims=[3], keepdim=False, dtype=DataType.Null) T22 = fd.ops.broadcast_in_dim( T21, shape=[16, 16, 128, 1], broadcast_dims=[0, 1, 2] ) @@ -2762,7 +2782,7 @@ def fusion_func(fd: FusionDefinition) -> None: end_indices=[12, 128, 25, 32, 1], strides=[1, 1, 1, 1, 1], ) - T89 = fd.ops.sum(T98, axes=[4], keepdim=False, dtype=DataType.Null) + T89 = fd.ops.sum(T98, dims=[4], keepdim=False, dtype=DataType.Null) fd.add_output(T89) nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) @@ -2800,8 +2820,8 @@ def fusion_func(fd: FusionDefinition) -> None: T6 = fd.ops.mul(T2, T5) T7 = fd.ops.cast(T0, dtype=DataType.Float) T8 = fd.ops.mul(T7, T5) - T24 = fd.ops.sum(T6, axes=[1], keepdim=False, dtype=DataType.Null) - T11 = fd.ops.sum(T8, axes=[0], keepdim=False, 
dtype=DataType.Null) + T24 = fd.ops.sum(T6, dims=[1], keepdim=False, dtype=DataType.Null) + T11 = fd.ops.sum(T8, dims=[0], keepdim=False, dtype=DataType.Null) fd.add_output(T24) fd.add_output(T11) @@ -2835,7 +2855,7 @@ def fusion_func(fd: FusionDefinition) -> None: ) S1 = fd.define_scalar(None, dtype=DataType.Double) T7 = fd.ops.reshape(T0, new_shape=[2, 1, 2]) - T8, T9 = fd.ops.var_mean(T7, axes=[2], correction=0, keepdim=False) + T8, T9 = fd.ops.var_mean(T7, dims=[2], correction=0, keepdim=False) T14 = fd.ops.broadcast_in_dim(T8, shape=[2, 1, 1], broadcast_dims=[0, 1]) T19 = fd.ops.broadcast_in_dim(T9, shape=[2, 1, 1], broadcast_dims=[0, 1]) T20 = fd.ops.add(T14, S1) @@ -2921,9 +2941,9 @@ def fusion_func(fd: FusionDefinition) -> None: T15 = fd.ops.cast( T3, dtype=DataType.Float ) # NOTE that RHS is same, but the result is assigned to different variables - T16 = fd.ops.sum(T15, axes=[0, 1], keepdim=False, dtype=DataType.Null) - T20 = fd.ops.sum(T14, axes=[0, 1], keepdim=False, dtype=DataType.Null) - T31 = fd.ops.sum(T14, axes=[2], keepdim=False, dtype=DataType.Null) + T16 = fd.ops.sum(T15, dims=[0, 1], keepdim=False, dtype=DataType.Null) + T20 = fd.ops.sum(T14, dims=[0, 1], keepdim=False, dtype=DataType.Null) + T31 = fd.ops.sum(T14, dims=[2], keepdim=False, dtype=DataType.Null) fd.add_output(T16) fd.add_output(T20) fd.add_output(T31) @@ -3151,6 +3171,72 @@ def fusion_func(fd: FusionDefinition) -> None: nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) # self.assertEqual(nvf_out[0], t24) + # Test that trivial reshapes whose inputs are reductions are concretized + # properly + # See https://github.com/NVIDIA/Fuser/issues/1691 + def test_issue1691(self): + inputs = [ + torch.randn((12,), dtype=torch.float32, device="cuda:0").as_strided( + (1, 3, 4), (12, 4, 1) + ), + torch.randn((12,), dtype=torch.float32, device="cuda:0").as_strided( + (4, 3), (3, 1) + ), + ] + + def fusion_func(fd: FusionDefinition) -> None: + T0 = fd.define_tensor( + shape=[1, -1, -1], + contiguity=[None, True, True], + dtype=DataType.Float, + is_cpu=False, + stride_order=[2, 1, 0], + ) + T1 = fd.define_tensor( + shape=[-1, -1], + contiguity=[True, True], + dtype=DataType.Float, + is_cpu=False, + stride_order=[1, 0], + ) + T2 = fd.ops.sum(T1, dims=[1], keepdim=False, dtype=DataType.Null) # 1D + T3 = fd.ops.sum(T0, dims=[1, 0], keepdim=False, dtype=DataType.Null) # 1D + S4 = fd.define_scalar(4, dtype=DataType.Int) + V5 = fd.define_vector([S4], dtype=DataType.Int) + T6 = fd.ops.reshape(T2, new_shape=V5) + S7 = fd.define_scalar(4, dtype=DataType.Int) + V8 = fd.define_vector([S7], dtype=DataType.Int) + T9 = fd.ops.reshape(T3, new_shape=V8) + T10 = fd.ops.mul(T6, T9) + T11 = fd.ops.sum(T10, dims=[0], keepdim=False, dtype=DataType.Null) + fd.add_output(T11) + + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + torch_ref = (inputs[0].sum(dim=[0, 1]) * inputs[1].sum(dim=1)).sum(dim=0) + self.assertEqual(nvf_out[0], torch_ref) + + # Test that expanded dimensions can be reduced properly + # See https://github.com/NVIDIA/Fuser/issues/1678 + def test_expanded_reduction(self): + inputs = [torch.tensor(1.0, device="cuda").as_strided((2, 3), (0, 0))] + + for keepdim in [False, True]: + + def fusion_func(fd: FusionDefinition) -> None: + T0 = fd.define_tensor( + shape=[-1, -1], + contiguity=[None, None], + dtype=DataType.Float, + is_cpu=False, + stride_order=[1, 0], + ) + T1 = fd.ops.sum(T0, dims=[0], keepdim=keepdim, dtype=DataType.Null) + fd.add_output(T1) + + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + + 
self.assertEqual(nvf_out[0], inputs[0].sum(dim=0, keepdim=keepdim)) + if __name__ == "__main__": run_tests() diff --git a/python_tests/test_schedule_ops.py b/python_tests/test_schedule_ops.py index f1585ab3dbe..ae37af18b5a 100644 --- a/python_tests/test_schedule_ops.py +++ b/python_tests/test_schedule_ops.py @@ -61,7 +61,7 @@ def check_input_error( def fusion_fn(fd: FusionDefinition): fd.t0 = fd.from_pytorch(inputs[0], static_sizes=True) - fd.t1 = fd.ops.sum(fd.t0, axis=-1) + fd.t1 = fd.ops.sum(fd.t0, dim=-1) fd.add_output(fd.t1) class InputError(FusionDefinition): @@ -85,7 +85,7 @@ def valid_use(self, sched_op_fn: Callable): def fusion_fn(fd: FusionDefinition): fd.t0 = fd.from_pytorch(inputs[0], static_sizes=True) - fd.t1 = fd.ops.sum(fd.t0, axis=-1) + fd.t1 = fd.ops.sum(fd.t0, dim=-1) fd.add_output(fd.t1) class BasicValid(FusionDefinition): diff --git a/runtime/grid_reduction.cu b/runtime/grid_reduction.cu index 76afb0038f9..8a75410d9de 100644 --- a/runtime/grid_reduction.cu +++ b/runtime/grid_reduction.cu @@ -625,19 +625,19 @@ __device__ void gridReduceGroup( // This performs a single reduction step, combining a single element "in" with // a previous value "work". For a serial grid reduction, "work" resides in -// global memory. +// global memory, while "in" and "out" are in registers. // // If the write predicate is false, this function returns early (noop). If the // read predicate is false, "init" is used in place of "in". // // If first_step is false, "work" will be read and reduction_op will be called. // The result will be written back to "work" unless last_step is true. -template +template __device__ void serialReductionStep( - T& out, - T in, + T* out, + T* in, T init, - volatile T& work, + volatile T* work, Func reduction_op, bool first_step, bool last_step, @@ -646,12 +646,24 @@ __device__ void serialReductionStep( if (!write_pred) { return; } - out = read_pred ? in : init; + if (read_pred) { + loadGeneric(out, in); + } else { +#pragma unroll + for (int i = 0; i < vec_size; ++i) { + out[i] = init; + } + } if (!first_step) { - reduction_op(out, work); + T work_reg[vec_size]; + loadGlobalToLocal(work_reg, work); +#pragma unroll + for (int i = 0; i < vec_size; ++i) { + reduction_op(out[i], work_reg[i]); + } } if (!last_step) { - work = out; + loadLocalToGlobal(work, out); } } diff --git a/setup.py b/setup.py index f5c80a1b186..0515d83000f 100644 --- a/setup.py +++ b/setup.py @@ -23,6 +23,9 @@ # --build-with-ucc # Build nvfuser with UCC support. You may need to specify environment variables of UCC_HOME, UCC_DIR, UCX_HOME, UCX_DIR. # +# --build-without-distributed +# Build nvfuser without multidevice support +# # --debug # Building nvfuser in debug mode # @@ -68,6 +71,7 @@ NO_NINJA = False BUILD_WITH_UCC = False BUILD_WITH_ASAN = False +BUILD_WITHOUT_DISTRIBUTED = False PATCH_NVFUSER = True OVERWRITE_VERSION = False VERSION_TAG = None @@ -100,6 +104,9 @@ if arg == "--build-with-asan": BUILD_WITH_ASAN = True continue + if arg == "--build-without-distributed": + BUILD_WITHOUT_DISTRIBUTED = True + continue if arg == "--debug": BUILD_TYPE = "Debug" continue @@ -283,7 +290,10 @@ def cmake(install_prefix: str = "./nvfuser"): if not os.path.exists(cmake_build_dir): os.makedirs(cmake_build_dir) - from tools.gen_nvfuser_version import get_pytorch_cmake_prefix + from tools.gen_nvfuser_version import ( + get_pytorch_cmake_prefix, + get_pytorch_use_distributed, + ) # this is used to suppress import error. 
# so we can get the right pytorch prefix for cmake @@ -297,6 +307,8 @@ def cmake(install_prefix: str = "./nvfuser"): logger.setLevel(logger_level) + pytorch_use_distributed = get_pytorch_use_distributed() + # generate cmake directory cmd_str = [ get_cmake_bin(), @@ -304,6 +316,7 @@ def cmake(install_prefix: str = "./nvfuser"): "-DCMAKE_BUILD_TYPE=" + BUILD_TYPE, f"-DCMAKE_INSTALL_PREFIX={install_prefix}", f"-DNVFUSER_CPP_STANDARD={CPP_STANDARD}", + f"-DUSE_DISTRIBUTED={pytorch_use_distributed}", "-B", cmake_build_dir, ] @@ -321,6 +334,8 @@ def cmake(install_prefix: str = "./nvfuser"): cmd_str.append("-DBUILD_NVFUSER_BENCHMARK=ON") if BUILD_WITH_ASAN: cmd_str.append("-DNVFUSER_BUILD_WITH_ASAN=ON") + if BUILD_WITHOUT_DISTRIBUTED: + cmd_str.append("-DNVFUSER_DISTRIBUTED=OFF") cmd_str.append(".") print(f"Configuring CMake with {' '.join(cmd_str)}") diff --git a/test/multidevice.cpp b/test/multidevice.cpp index 2b90d8eafe2..8596f156243 100644 --- a/test/multidevice.cpp +++ b/test/multidevice.cpp @@ -5,7 +5,7 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#ifdef USE_DISTRIBUTED +#ifdef NVFUSER_DISTRIBUTED #include #include #include diff --git a/test/multidevice.h b/test/multidevice.h index 0cee504577a..91f09930e2e 100644 --- a/test/multidevice.h +++ b/test/multidevice.h @@ -5,7 +5,7 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#ifdef USE_DISTRIBUTED +#ifdef NVFUSER_DISTRIBUTED #pragma once #include diff --git a/test/test_alias.cpp b/test/test_alias.cpp index 1aeb338d33f..3c095b22493 100644 --- a/test/test_alias.cpp +++ b/test/test_alias.cpp @@ -13,6 +13,7 @@ #include #include +#include #include #include #include @@ -22,9 +23,11 @@ namespace nvfuser { using testing::_; +using testing::Contains; using testing::ContainsRegex; using testing::Each; using testing::ElementsAre; +using testing::Field; using testing::IsEmpty; using testing::IsTrue; using testing::Not; @@ -1066,7 +1069,7 @@ TEST_F(AliasTest, InPlaceUpdateAliasAcrossSegments) { TensorView* tv6 = add(tv5, tv2); // Group 1 (Broadcast after reduce) // Note: test alias; - fusion->aliasOutputToInput(tv6, tv0, AliasType::InplaceUpdate); + fusion->aliasOutputToInput(tv6, tv0, AllocationType::InplaceUpdate); // TODO: support output on aliased fusion #1488 // remove tv7 after #1488 // fusion->addOutput(tv6); @@ -1106,4 +1109,41 @@ TEST_F(AliasTest, InPlaceUpdateAliasAcrossSegments) { << "`t0` should have been in-place updated to the same value as `t6`."; } +TEST_F(AliasTest, AliasOnlyKernelsAreNotLaunched) { + ProfilerOptionsGuard option_guard; + ProfilerOptionsGuard::getCurOptions().set(ProfilerOption::Enable); + FusionProfiler::start(); + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + // The segment between `add_out` and `permute_out` is meta-op only and + // turned into a no-op kernel. + TensorView* in = makeContigConcreteTensor({2, 3}); + TensorView* add_out = add(in, in); + TensorView* permute_out = permute(add_out, {1, 0}); + + fusion->addInput(in); + fusion->addOutput(add_out); + fusion->addOutput(permute_out); + + FusionExecutorCache fec(std::move(fusion)); + auto options = at::dtype(at::kFloat).device(at::kCUDA); + at::Tensor in_tensor = at::randn({2, 3}, options); + fec.runFusionWithInputs({in_tensor}); + + const FusionProfile& profile = FusionProfiler::profile(); + // Expect a kernel launched for one of the two segments but not the + // other. 
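+  // A segment that launches no kernel shows up in the profile with an empty
+  // kernel name, so exactly one of the two entries should have a non-empty name.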
+ EXPECT_THAT( + profile.kernel_profiles, + UnorderedElementsAre( + Field(&KernelProfile::name, IsEmpty()), + Field(&KernelProfile::name, Not(IsEmpty())))); + + if (ProfilerState::Running == FusionProfiler::state()) { + FusionProfiler::stop(); + } +} + } // namespace nvfuser diff --git a/test/test_combine_mul_sum.cpp b/test/test_combine_mul_sum.cpp index 99e34d7602e..dd510bc8877 100644 --- a/test/test_combine_mul_sum.cpp +++ b/test/test_combine_mul_sum.cpp @@ -185,7 +185,7 @@ TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Schedule) { params.double_buffer_options.smem_double_buffer_stage = 4; scheduleMatmul(&fusion, params); - auto inputs = matmulAtInput(M, N, K, layout); + auto inputs = matmulAtInput2D(M, N, K, layout); FusionExecutor fe; fe.compileFusion( @@ -215,8 +215,8 @@ TEST_F(CombineMulSumAsMmaTest, UseMatmulScheduler) { fusion->addOutput(tv2); ASSERT_TRUE(ir_utils::getOpsOfType(fusion.get()).empty()); - auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at::kHalf, M, N, K); - auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at::kHalf, M, N, K); + auto t0 = matmulAtInput2D(layout, TensorMatmulPos::A, at::kHalf, M, N, K); + auto t1 = matmulAtInput2D(layout, TensorMatmulPos::B, at::kHalf, M, N, K); auto tref = atMatmul(t0, t1, layout); FusionExecutorCache executor_cache(std::move(fusion)); diff --git a/test/test_evaluator.cpp b/test/test_evaluator.cpp index db1b1b47bef..02b768ab6f6 100644 --- a/test/test_evaluator.cpp +++ b/test/test_evaluator.cpp @@ -673,4 +673,35 @@ TEST_F(ExprEvalTest, SumDiv) { evaluator.evaluate(out); } +TEST_F(ExprEvalTest, MmaOp) { + int64_t m = 2, k = 3, n = 4; + + Fusion fusion; + FusionGuard fg(&fusion); + + // The matmul API will expect inputs in the shape [M,K] x [K,N]. + // This is compatible with PyTorch, + std::vector a_shape{m, k}, b_shape{k, n}, out_shape{m, n}; + + auto tv0 = makeConcreteTensor(a_shape, DataType::Half); + auto tv1 = makeConcreteTensor(b_shape, DataType::Half); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv0b = broadcast(tv0, {false, false, true}); // [M, K, 1] + auto tv1b = broadcast(tv1, {true, false, false}); // [1, K, N] + auto tv2 = fusedMultiplySum(tv0b, tv1b, {1}); + + fusion.addOutput(tv2); + + at::Tensor in_a = at::ones(a_shape, at::kHalf).cuda(); + at::Tensor in_b = at::ones(b_shape, at::kHalf).cuda(); + at::Tensor out_ref = at::full(out_shape, k, at::kFloat).cuda(); + + ExpressionEvaluator evaluator; + evaluator.bind(tv0, in_a); + evaluator.bind(tv1, in_b); + at::Tensor out = evaluator.evaluate(tv2).as(); + EXPECT_TRUE(at::allclose(out, out_ref)); +} } // namespace nvfuser diff --git a/test/test_external_src.cpp b/test/test_external_src.cpp index 9b7fbf3f016..880c2c82c93 100644 --- a/test/test_external_src.cpp +++ b/test/test_external_src.cpp @@ -117,7 +117,7 @@ TEST_F(ExternalSrcExample, Matmul_CUDA) { int M = 2048, N = 3456, K = 2048; MmaLayout layout = MmaLayout::TN; - auto inputs = matmulAtInput(M, N, K, layout); + auto inputs = matmulAtInput2D(M, N, K, layout); auto at_output = atMatmul(inputs.first, inputs.second, layout).to(at::kFloat); LaunchParams lp(16, 27, 1, 32, 2, 2); diff --git a/test/test_gpu2.cpp b/test/test_gpu2.cpp index 615f3a2c415..69f26977932 100644 --- a/test/test_gpu2.cpp +++ b/test/test_gpu2.cpp @@ -5210,7 +5210,7 @@ TEST_F(NVFuserTest, FusionSegmentVerticalMerge_CUDA) { args.push(t0); auto segmented_fusion = - SegmentCandidateFinder::segment(fusion.get(), args, segment_options); + SegmentCandidateFinder::segment(fusion.get(), &args, segment_options); 
NVF_CHECK(segmented_fusion->groups().size() == 2); } @@ -5256,7 +5256,7 @@ TEST_F(NVFuserTest, FusionSegmentHorizontalMerge_CUDA) { args.push(scalar); auto segmented_fusion = - SegmentCandidateFinder::segment(fusion.get(), args, segment_options); + SegmentCandidateFinder::segment(fusion.get(), &args, segment_options); NVF_CHECK(segmented_fusion->groups().size() == 2); } @@ -5299,7 +5299,7 @@ TEST_F(NVFuserTest, FusionSegmentMixReduction_CUDA) { args.push(t0); auto segmented_fusion = - SegmentCandidateFinder::segment(fusion.get(), args, segment_options); + SegmentCandidateFinder::segment(fusion.get(), &args, segment_options); NVF_CHECK(segmented_fusion->groups().size() <= 2); } @@ -7844,7 +7844,7 @@ TEST_F(NVFuserTest, FusionSegmenterCombineReductionsCycleRepro_CUDA) { for (auto i : c10::irange(5)) { (void)i; // Suppress unused variable warning auto segmented_fusion = - SegmentCandidateFinder::segment(fusion_ptr.get(), args); + SegmentCandidateFinder::segment(fusion_ptr.get(), &args); } } diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index 5bea913af90..18e06fb4d1f 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -3496,58 +3496,6 @@ TEST_F(NVFuserTest, FusionExpandReduce_CUDA) { testValidate(executor_cache.fusion(), cg_outputs, {t0}, __LINE__, __FILE__); } -// Predicate elimination issue repro: -TEST_F(NVFuserTest, FusionExpandReduce2_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeConcreteTensor({1, 4}); - fusion->addInput(tv0); - - auto tv1 = - expand(tv0, {IrBuilder::create(3L), IrBuilder::create(4L)}); - - auto tv2 = sum(tv1, {0}); - fusion->addOutput(tv2); - - // tv2[r{3}, i{4}] - tv2->split(0, NamedScalar::getParallelDim(ParallelType::TIDy)); - tv2->axis(1)->parallelize(ParallelType::TIDy); - tv2->split(0, NamedScalar::getParallelDim(ParallelType::BIDy), false); - tv2->axis(0)->parallelize(ParallelType::BIDy); - tv2->split(-1, NamedScalar::getParallelDim(ParallelType::TIDx)); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(-2)->parallelize(ParallelType::BIDx); - // [rBIDy, rO, rTIDy, iBIDx, iTIDx] - tv2->reorder({{-2, 0}, {-1, 1}, {2, 2}}); - // [iBIDx, iTIDx, rTIDy, rBIDy, rO] - auto tv3 = tv2->rFactor({-1}); - - TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); - scheduler_utils::parallelizeAllLike(tv3); - tv0->computeAt(tv3, -1, ComputeAtMode::MostInlined); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn({1, 4}, options); - - FusionExecutor fe; - fe.compileFusion(fusion.get(), {t0}, LaunchParams(-1, 2, -1, 4, 2, 1)); - auto cg_outputs = fe.runFusion({t0}, LaunchParams(-1, 2, -1, 4, 2, 1)); - - auto ref = t0.expand({3, 4}).sum({0}); - - testValidate( - fusion.get(), - cg_outputs, - {t0}, - {ref}, - __LINE__, - __FILE__, - "", - LaunchParams(-1, 2, -1, 4, 2, 1)); -} - TEST_F(NVFuserTest, FusionVectorComponentReduce_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -8716,6 +8664,80 @@ TEST_F(NVFuserTest, Reduction3DConstantIterationDomain) { executor_cache.fusion(), cg_outputs, inputs, {ref}, __LINE__, __FILE__); } +// don't cache if the input tv is used by slice. +// https://github.com/NVIDIA/Fuser/issues/1697 +TEST_F(NVFuserTest, AvoidCachingSliceInput) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + // values to trigger the original bug. 
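+  // t1 is reshaped to [8, 512*20]; two 512-wide slices of it are taken, and the
+  // second one feeds into the inner-persistent section computed from t0.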
+ const int64_t eight = 8; + const int64_t twenty = 20; + const int64_t fiveTwelve = 512; + const int64_t batch_size = 128; + const int64_t hidden_size = 4096; + DataType input_dtype = DataType::Half; + auto tv0 = makeContigTensor(2, input_dtype); + auto tv1 = makeContigTensor(1, input_dtype); + fusion->addInput(tv0); + fusion->addInput(tv1); + + // inner persistent + auto tv2 = castOp(DataType::Float, tv0); + auto tv3 = exp(tv2); + auto tv4 = sum(tv3, {-1}); + auto tv5 = broadcast(tv4, {false, true}); + auto tv6 = div(tv3, tv5); + + // reshape t1 to [8, 512*20] + auto val_8 = IrBuilder::create(eight, DataType::Index); + auto val_512x20 = + IrBuilder::create(fiveTwelve * twenty, DataType::Index); + auto tv7 = reshape(tv1, {val_8, val_512x20}); + + // slice-1 reshape to hidden size + auto val_4096 = IrBuilder::create(hidden_size, DataType::Index); + auto tv8 = slice(tv7, {0, 0}, {eight, fiveTwelve}); + auto tv9 = reshape(tv8, {val_4096}); + auto tv10 = broadcast(tv9, {true, false}); + auto tv11 = castOp(DataType::Float, tv10); + fusion->addOutput(tv11); + + // slice-2 reshape to hidden size and link with inner persistent + auto tv12 = slice(tv7, {0, fiveTwelve * 3}, {eight, fiveTwelve * 4}); + auto tv13 = reshape(tv12, {val_4096}); + auto tv14 = broadcast(tv13, {true, false}); + auto tv15 = castOp(DataType::Float, tv14); + auto tv16 = mul(tv6, tv15); + fusion->addOutput(tv16); + + auto options = at::TensorOptions() + .dtype(data_type_to_aten(input_dtype)) + .device(at::kCUDA, 0); + auto t0 = at::randn({batch_size, hidden_size}, options); + auto t1 = at::randn({eight * fiveTwelve * twenty}, options); + std::vector inputs{t0, t1}; + + FusionExecutorCache executor_cache(std::move(fusion)); + auto cg_outputs = executor_cache.runFusionWithInputs(inputs); + + // check segment and sliced tvs are not cached + auto kernel_runtime = executor_cache.getMostRecentKernelRuntime(); + NVF_CHECK(kernel_runtime->isSegmented(), "segmentation didn't happen"); + const auto num_segments = kernel_runtime->fusionSegments()->groups().size(); + NVF_CHECK(num_segments == 3, "Expect 3 segments, got: ", num_segments); + for (const auto& fe : kernel_runtime->executors()) { + for (auto expr : fe.kernel()->exprs()) { + if (expr->isA()) { + auto slice = expr->as(); + NVF_CHECK( + slice->in()->getMemoryType() == MemoryType::Global, + "slice input must be in global memory, get: ", + slice->in()->getMemoryType()); + } + } + } +} // Test file size should be up to 10K LoC. Create a new file for more tests. 
} // namespace nvfuser diff --git a/test/test_gpu_fused_reduction.cpp b/test/test_gpu_fused_reduction.cpp index a06d6d54f49..476da1981ba 100644 --- a/test/test_gpu_fused_reduction.cpp +++ b/test/test_gpu_fused_reduction.cpp @@ -2559,4 +2559,29 @@ TEST_F(NVFuserTest, FusionCrossEntropyGatherPattern_CUDA) { testValidate(&fusion, cg_outputs, inputs, {ref}, __LINE__, __FILE__); } +TEST_F(NVFuserTest, FusionTensorRankLimit) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + std::vector input_shape; + for (__attribute__((unused)) auto i : c10::irange(12)) { + input_shape.push_back(3); + } + + auto tv0 = makeSymbolicTensor(input_shape.size()); + fusion->addInput(tv0); + auto tv1 = sum(tv0, {3}); + fusion->addOutput(tv1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn(input_shape, options); + std::vector aten_inputs({t0}); + + FusionExecutorCache executor_cache(std::move(fusion)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + testValidate( + executor_cache.fusion(), cg_outputs, aten_inputs, __LINE__, __FILE__); +} + } // namespace nvfuser diff --git a/test/test_gpu_tensorcore.cpp b/test/test_gpu_tensorcore.cpp index 07b1c29ee2d..7b1af53c8f2 100644 --- a/test/test_gpu_tensorcore.cpp +++ b/test/test_gpu_tensorcore.cpp @@ -87,7 +87,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmul_CUDA) { params.double_buffer_options.smem_double_buffer_stage = 4; scheduleMatmul(&fusion, params); - auto inputs = matmulAtInput(M, N, K, layout); + auto inputs = matmulAtInput2D(M, N, K, layout); FusionExecutor fe; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( @@ -103,6 +103,14 @@ TEST_F(NVFuserTest, FusionAmpereMatmul_CUDA) { auto tref = atMatmul( inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); + + // Check that computed smem matches actually allocated smem + mma_utils::MmaDataTypes data_types = { + DataType::Half, DataType::Half, DataType::Float}; + int64_t estimated_smem = mma_utils::computeExpectedSharedMemoryUsage( + params, data_types, true, true); + int64_t actual_smem = fe.lastLaunchParams().smem(); + EXPECT_EQ(estimated_smem, actual_smem); } } @@ -137,7 +145,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulBFloat16_CUDA) { params.double_buffer_options.smem_double_buffer_stage = 4; scheduleMatmul(&fusion, params); - auto inputs = matmulAtInput(M, N, K, layout, at::kBFloat16); + auto inputs = matmulAtInput2D(M, N, K, layout, at::kBFloat16); FusionExecutor fe; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( @@ -153,6 +161,14 @@ TEST_F(NVFuserTest, FusionAmpereMatmulBFloat16_CUDA) { auto tref = atMatmul( inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); + + // Check that computed smem matches actually allocated smem + mma_utils::MmaDataTypes data_types = { + DataType::Half, DataType::Half, DataType::Float}; + int64_t estimated_smem = mma_utils::computeExpectedSharedMemoryUsage( + params, data_types, true, true); + int64_t actual_smem = fe.lastLaunchParams().smem(); + EXPECT_EQ(estimated_smem, actual_smem); } } @@ -191,7 +207,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulPipelineGmem_CUDA) { params.double_buffer_options.smem_double_buffer_stage = stage; scheduleMatmul(&fusion, params); - auto inputs = matmulAtInput(M, N, K, layout); + auto inputs = matmulAtInput2D(M, N, K, layout); FusionExecutor fe; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( @@ -207,6 +223,14 @@ TEST_F(NVFuserTest, 
FusionAmpereMatmulPipelineGmem_CUDA) { auto tref = atMatmul( inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); + + // Check that computed smem matches actually allocated smem + mma_utils::MmaDataTypes data_types = { + DataType::Half, DataType::Half, DataType::Float}; + int64_t estimated_smem = mma_utils::computeExpectedSharedMemoryUsage( + params, data_types, true, true); + int64_t actual_smem = fe.lastLaunchParams().smem(); + ASSERT_EQ(estimated_smem, actual_smem); } } } @@ -259,7 +283,7 @@ TEST_F(NVFuserTest, FusionAmpereSwizzle_CUDA) { scheduleMatmul(&fusion, params); - auto inputs = matmulAtInput(M, N, K, layout); + auto inputs = matmulAtInput2D(M, N, K, layout); FusionExecutor fe; fe.setMeasureKernelTimeFlag(true); @@ -372,7 +396,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulRegDoubleBuffer_CUDA) { params.double_buffer_options.double_buffer_smem_read = true; scheduleMatmul(&fusion, params); - auto inputs = matmulAtInput(M, N, K, layout); + auto inputs = matmulAtInput2D(M, N, K, layout); FusionExecutor fe; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( @@ -388,6 +412,14 @@ TEST_F(NVFuserTest, FusionAmpereMatmulRegDoubleBuffer_CUDA) { auto tref = atMatmul( inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); + + // Check that computed smem matches actually allocated smem + mma_utils::MmaDataTypes data_types = { + DataType::Half, DataType::Half, DataType::Float}; + int64_t estimated_smem = mma_utils::computeExpectedSharedMemoryUsage( + params, data_types, true, true); + int64_t actual_smem = fe.lastLaunchParams().smem(); + EXPECT_EQ(estimated_smem, actual_smem); } } } @@ -1044,7 +1076,7 @@ TEST_F(NVFuserTest, FusionTuringMatmul_CUDA) { params.tile_sizes = gemm_tile; scheduleMatmul(&fusion, params); - auto inputs = matmulAtInput(M, N, K, layout); + auto inputs = matmulAtInput2D(M, N, K, layout); FusionExecutor fe; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( @@ -1054,6 +1086,14 @@ TEST_F(NVFuserTest, FusionTuringMatmul_CUDA) { auto tref = atMatmul( inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); + + // Check that computed smem matches actually allocated smem + mma_utils::MmaDataTypes data_types = { + DataType::Half, DataType::Half, DataType::Float}; + int64_t estimated_smem = mma_utils::computeExpectedSharedMemoryUsage( + params, data_types, true, true); + int64_t actual_smem = fe.lastLaunchParams().smem(); + EXPECT_EQ(estimated_smem, actual_smem); } } @@ -1719,7 +1759,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoad_CUDA) { params.double_buffer_options.smem_double_buffer_stage = 3; scheduleMatmul(&fusion, params); - auto inputs = matmulAtInput(M, N, K, layout); + auto inputs = matmulAtInput2D(M, N, K, layout); FusionExecutor fe; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( @@ -1735,6 +1775,14 @@ TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoad_CUDA) { auto tref = atMatmul( inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); + + // Check that computed smem matches actually allocated smem + mma_utils::MmaDataTypes data_types = { + DataType::Half, DataType::Half, DataType::Float}; + int64_t estimated_smem = mma_utils::computeExpectedSharedMemoryUsage( + params, data_types, true, true); + int64_t actual_smem = fe.lastLaunchParams().smem(); + EXPECT_EQ(estimated_smem, actual_smem); } } @@ -1766,7 +1814,7 @@ 
TEST_F(NVFuserTest, FusionTuringMatmulLargeLoad_CUDA) { params.tile_sizes = gemm_tile; scheduleMatmul(&fusion, params); - auto inputs = matmulAtInput(M, N, K, layout); + auto inputs = matmulAtInput2D(M, N, K, layout); FusionExecutor fe; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( @@ -1782,6 +1830,14 @@ TEST_F(NVFuserTest, FusionTuringMatmulLargeLoad_CUDA) { auto tref = atMatmul( inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); + + // Check that computed smem matches actually allocated smem + mma_utils::MmaDataTypes data_types = { + DataType::Half, DataType::Half, DataType::Float}; + int64_t estimated_smem = mma_utils::computeExpectedSharedMemoryUsage( + params, data_types, true, true); + int64_t actual_smem = fe.lastLaunchParams().smem(); + EXPECT_EQ(estimated_smem, actual_smem); } } @@ -1818,14 +1874,18 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck4warp_CUDA) { params.tile_sizes = gemm_tile; params.async_gmem_load_operands = true; params.double_buffer_options.double_buffer_smem_write = true; + mma_utils::MmaDataTypes data_types = { + DataType::Half, DataType::Half, DataType::Float}; std::tie(params.use_smem_epilogue, params.promote_prologue_smem_reuse) = mma_utils::generateSharedMemoryEpilogueHeuristics( gemm_tile, params.double_buffer_options.smem_double_buffer_stage, - {DataType::Half, DataType::Half, DataType::Float}); + data_types, + true, + true); scheduleMatmul(&fusion, params); - auto inputs = matmulAtInput(M, N, K, layout); + auto inputs = matmulAtInput2D(M, N, K, layout); FusionExecutor fe; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( @@ -1836,7 +1896,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck4warp_CUDA) { {inputs.first, inputs.second}, LaunchParams(), matmul_cparams)); - ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + EXPECT_TRUE(getBankConflictInfo(fe.kernel()).empty()); auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); auto tref = atMatmul( inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); @@ -1848,6 +1908,11 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck4warp_CUDA) { mn_size, " ", k_size); + // Check that computed smem matches actually allocated smem + int64_t estimated_smem = mma_utils::computeExpectedSharedMemoryUsage( + params, data_types, true, true); + int64_t actual_smem = fe.lastLaunchParams().smem(); + EXPECT_EQ(estimated_smem, actual_smem); } } } @@ -1886,16 +1951,18 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck8warp_CUDA) { params.double_buffer_options.double_buffer_smem_write = true; params.double_buffer_options.double_buffer_smem_read = true; params.double_buffer_options.smem_double_buffer_stage = 2; + mma_utils::MmaDataTypes data_types = { + DataType::Half, DataType::Half, DataType::Float}; std::tie( params.use_smem_epilogue, params.promote_prologue_smem_reuse) = mma_utils::generateSharedMemoryEpilogueHeuristics( gemm_tile, params.double_buffer_options.smem_double_buffer_stage, - {DataType::Half, DataType::Half, DataType::Float}); + data_types); scheduleMatmul(&fusion, params); - auto inputs = matmulAtInput(M, N, K, layout); + auto inputs = matmulAtInput2D(M, N, K, layout); FusionExecutor fe; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( @@ -1913,6 +1980,11 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck8warp_CUDA) { inputs.second.to(at::kFloat), layout); NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); + // Check that computed smem matches actually allocated smem + int64_t estimated_smem = mma_utils::computeExpectedSharedMemoryUsage( 
+ params, data_types, true, true); + int64_t actual_smem = fe.lastLaunchParams().smem(); + EXPECT_EQ(estimated_smem, actual_smem); } } } @@ -1950,14 +2022,16 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck6warp_CUDA) { params.double_buffer_options.double_buffer_smem_write = true; params.double_buffer_options.double_buffer_smem_read = true; params.double_buffer_options.smem_double_buffer_stage = 2; + mma_utils::MmaDataTypes data_types = { + DataType::Half, DataType::Half, DataType::Float}; std::tie(params.use_smem_epilogue, params.promote_prologue_smem_reuse) = mma_utils::generateSharedMemoryEpilogueHeuristics( gemm_tile, params.double_buffer_options.smem_double_buffer_stage, - {DataType::Half, DataType::Half, DataType::Float}); + data_types); scheduleMatmul(&fusion, params); - auto inputs = matmulAtInput(M, N, K, layout); + auto inputs = matmulAtInput2D(M, N, K, layout); FusionExecutor fe; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( @@ -1973,6 +2047,11 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck6warp_CUDA) { auto tref = atMatmul( inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); + // Check that computed smem matches actually allocated smem + int64_t estimated_smem = mma_utils::computeExpectedSharedMemoryUsage( + params, data_types, true, true); + int64_t actual_smem = fe.lastLaunchParams().smem(); + EXPECT_EQ(estimated_smem, actual_smem); } } } @@ -2008,7 +2087,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoadLargeK_CUDA) { params.double_buffer_options.smem_double_buffer_stage = 3; scheduleMatmul(&fusion, params); - auto inputs = matmulAtInput(M, N, K, layout); + auto inputs = matmulAtInput2D(M, N, K, layout); FusionExecutor fe; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( @@ -2059,8 +2138,10 @@ TEST_F(NVFuserTest, FusionAmpereSplitKLikeStridedBatchedMatmul_CUDA) { params.double_buffer_options.smem_double_buffer_stage = 4; scheduleMatmul(&fusion, params); - auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at::kHalf, M, N, K, B); - auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at::kHalf, M, N, K, B); + auto t0 = + matmulAtInput2D(layout, TensorMatmulPos::A, at::kHalf, M, N, K, B); + auto t1 = + matmulAtInput2D(layout, TensorMatmulPos::B, at::kHalf, M, N, K, B); FusionExecutor fe; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( @@ -2110,11 +2191,13 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSmemEpilogue_CUDA) { params.double_buffer_options.double_buffer_smem_write = true; params.double_buffer_options.double_buffer_smem_read = true; params.double_buffer_options.smem_double_buffer_stage = 2; + mma_utils::MmaDataTypes data_types = { + DataType::Half, DataType::Half, DataType::Float}; std::tie(params.use_smem_epilogue, params.promote_prologue_smem_reuse) = mma_utils::generateSharedMemoryEpilogueHeuristics( gemm_tile, params.double_buffer_options.smem_double_buffer_stage, - {DataType::Half, DataType::Half, DataType::Float}, + data_types, ignore_occupancy_drop); scheduleMatmul(&fusion, params); @@ -2136,7 +2219,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSmemEpilogue_CUDA) { num_shared_mem_tensors); at::manual_seed(0); - auto inputs = matmulAtInput(M, N, K, layout); + auto inputs = matmulAtInput2D(M, N, K, layout); FusionExecutor fe; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( @@ -2158,6 +2241,11 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSmemEpilogue_CUDA) { cg_outputs[0].allclose(tref, 0.01, 0.01), "Result validation failed. 
Max diff: ", (cg_outputs[0] - tref).abs().max()); + // Check that computed smem matches actually allocated smem + int64_t estimated_smem = mma_utils::computeExpectedSharedMemoryUsage( + params, data_types, true, true); + int64_t actual_smem = fe.lastLaunchParams().smem(); + EXPECT_EQ(estimated_smem, actual_smem); if (!params.use_smem_epilogue) { GTEST_SKIP() @@ -2238,11 +2326,13 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSmemEpilogueCast_CUDA) { params.double_buffer_options.double_buffer_smem_write = true; params.double_buffer_options.double_buffer_smem_read = true; params.double_buffer_options.smem_double_buffer_stage = 4; + mma_utils::MmaDataTypes data_types = { + DataType::Half, DataType::Half, DataType::Float}; std::tie(params.use_smem_epilogue, params.promote_prologue_smem_reuse) = mma_utils::generateSharedMemoryEpilogueHeuristics( gemm_tile, params.double_buffer_options.smem_double_buffer_stage, - {DataType::Half, DataType::Half, DataType::Float}, + data_types, ignore_occupancy_drop); scheduleMatmul(&fusion, params); @@ -2264,7 +2354,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSmemEpilogueCast_CUDA) { num_shared_mem_tensors); at::manual_seed(0); - auto inputs = matmulAtInput(M, N, K, layout); + auto inputs = matmulAtInput2D(M, N, K, layout); FusionExecutor fe; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( @@ -2287,6 +2377,12 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSmemEpilogueCast_CUDA) { "Result validation failed. Max diff: ", (cg_outputs[0] - tref).abs().max()); + // Check that computed smem matches actually allocated smem + int64_t estimated_smem = mma_utils::computeExpectedSharedMemoryUsage( + params, data_types, true, true); + int64_t actual_smem = fe.lastLaunchParams().smem(); + EXPECT_EQ(estimated_smem, actual_smem); + if (!params.use_smem_epilogue) { GTEST_SKIP() << "Test conducted without utilizing shared memory epilogue due to the device's constrained shared memory capacity."; @@ -2325,11 +2421,13 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSmemEpilogueRelu_CUDA) { params.double_buffer_options.double_buffer_smem_write = true; params.double_buffer_options.double_buffer_smem_read = true; params.double_buffer_options.smem_double_buffer_stage = 4; + mma_utils::MmaDataTypes data_types = { + DataType::Half, DataType::Half, DataType::Float}; std::tie(params.use_smem_epilogue, params.promote_prologue_smem_reuse) = mma_utils::generateSharedMemoryEpilogueHeuristics( gemm_tile, params.double_buffer_options.smem_double_buffer_stage, - {DataType::Half, DataType::Half, DataType::Float}, + data_types, ignore_occupancy_drop); scheduleMatmul(&fusion, params); @@ -2351,7 +2449,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSmemEpilogueRelu_CUDA) { num_shared_mem_tensors); at::manual_seed(0); - auto inputs = matmulAtInput(M, N, K, layout); + auto inputs = matmulAtInput2D(M, N, K, layout); FusionExecutor fe; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( @@ -2375,6 +2473,12 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSmemEpilogueRelu_CUDA) { "Result validation failed. 
Max diff: ", (cg_outputs[0] - tref).abs().max()); + // Check that computed smem matches actually allocated smem + int64_t estimated_smem = mma_utils::computeExpectedSharedMemoryUsage( + params, data_types, true, true); + int64_t actual_smem = fe.lastLaunchParams().smem(); + EXPECT_EQ(estimated_smem, actual_smem); + if (!params.use_smem_epilogue) { GTEST_SKIP() << "Test conducted without utilizing shared memory epilogue due to the device's constrained shared memory capacity."; @@ -2416,7 +2520,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSplitK_CUDA) { params.splitk_factor = 2; scheduleMatmul(&fusion, params); - auto inputs = matmulAtInput(M, N, K, layout); + auto inputs = matmulAtInput2D(M, N, K, layout); FusionExecutor fe; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( @@ -2428,6 +2532,14 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSplitK_CUDA) { // Relax tolerance for larger sum due to large K NVF_CHECK(cg_outputs[0].allclose(tref, 1e-6 * K, 1e-6 * K)); + + // Check that computed smem matches actually allocated smem + mma_utils::MmaDataTypes data_types = { + DataType::Half, DataType::Half, DataType::Float}; + int64_t estimated_smem = mma_utils::computeExpectedSharedMemoryUsage( + params, data_types, true, true); + int64_t actual_smem = fe.lastLaunchParams().smem(); + EXPECT_EQ(estimated_smem, actual_smem); } } @@ -2469,7 +2581,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSplitKBias_CUDA) { params.splitk_factor = 2; scheduleMatmul(&fusion, params); - auto [aten_a, aten_b] = matmulAtInput(M, N, K, layout); + auto [aten_a, aten_b] = matmulAtInput2D(M, N, K, layout); at::Tensor aten_bias = at::randn({M}, aten_a.options()); std::vector inputs = {aten_a, aten_b, aten_bias}; @@ -2484,6 +2596,14 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSplitKBias_CUDA) { // Relax tolerance for larger sum due to large K NVF_CHECK(cg_outputs[0].allclose(tref, 1e-6 * K, 1e-6 * K)); + + // Check that computed smem matches actually allocated smem + mma_utils::MmaDataTypes data_types = { + DataType::Half, DataType::Half, DataType::Float}; + int64_t estimated_smem = mma_utils::computeExpectedSharedMemoryUsage( + params, data_types, true, true); + int64_t actual_smem = fe.lastLaunchParams().smem(); + EXPECT_EQ(estimated_smem, actual_smem); } } @@ -2522,9 +2642,9 @@ TEST_F(NVFuserTest, FusionAmpereMatmulBatchSplitK_CUDA) { scheduleMatmul(&fusion, params); at::Tensor aten_a = - matmulAtInput(layout, TensorMatmulPos::A, at::kHalf, M, N, K, B); + matmulAtInput2D(layout, TensorMatmulPos::A, at::kHalf, M, N, K, B); at::Tensor aten_b = - matmulAtInput(layout, TensorMatmulPos::B, at::kHalf, M, N, K, B); + matmulAtInput2D(layout, TensorMatmulPos::B, at::kHalf, M, N, K, B); std::vector inputs = {aten_a, aten_b}; @@ -2537,6 +2657,14 @@ TEST_F(NVFuserTest, FusionAmpereMatmulBatchSplitK_CUDA) { // Relax tolerance for larger sum due to large K EXPECT_TRUE(cg_outputs[0].allclose(tref, 1e-6 * K, 1e-6 * K)); + + // Check that computed smem matches actually allocated smem + mma_utils::MmaDataTypes data_types = { + DataType::Half, DataType::Half, DataType::Float}; + int64_t estimated_smem = mma_utils::computeExpectedSharedMemoryUsage( + params, data_types, true, true); + int64_t actual_smem = fe.lastLaunchParams().smem(); + EXPECT_EQ(estimated_smem, actual_smem); } } @@ -2579,9 +2707,9 @@ TEST_F(NVFuserTest, FusionAmpereMatmulBatchSplitKBias_CUDA) { scheduleMatmul(&fusion, params); at::Tensor aten_a = - matmulAtInput(layout, TensorMatmulPos::A, at::kHalf, M, N, K, B); + matmulAtInput2D(layout, TensorMatmulPos::A, at::kHalf, M, N, K, B); at::Tensor aten_b 
= - matmulAtInput(layout, TensorMatmulPos::B, at::kHalf, M, N, K, B); + matmulAtInput2D(layout, TensorMatmulPos::B, at::kHalf, M, N, K, B); at::Tensor aten_bias = at::randn({M}, aten_a.options()); std::vector inputs = {aten_a, aten_b, aten_bias}; @@ -2597,6 +2725,14 @@ TEST_F(NVFuserTest, FusionAmpereMatmulBatchSplitKBias_CUDA) { // Relax tolerance for larger sum due to large K EXPECT_TRUE(cg_outputs[0].allclose(tref, 1e-6 * K, 1e-6 * K)); + + // Check that computed smem matches actually allocated smem + mma_utils::MmaDataTypes data_types = { + DataType::Half, DataType::Half, DataType::Float}; + int64_t estimated_smem = mma_utils::computeExpectedSharedMemoryUsage( + params, data_types, true, true); + int64_t actual_smem = fe.lastLaunchParams().smem(); + EXPECT_EQ(estimated_smem, actual_smem); } } diff --git a/test/test_gpu_transpose.cpp b/test/test_gpu_transpose.cpp index c54e2cbba74..9fe61d6c3d7 100644 --- a/test/test_gpu_transpose.cpp +++ b/test/test_gpu_transpose.cpp @@ -1335,4 +1335,57 @@ TEST_F(TransposeTest, TransposeSplitAggregatedVectorizationWidth) { NVF_CHECK(ref.equal(cg_outputs.at(0))); } +// Testing transpose scheduler to handle fusion inputs with reduction IterDomain +// produced by segmented fusion, see issue +// https://github.com/NVIDIA/Fuser/issues/1659 for details +TEST_F(TransposeTest, ReductionIterDomainOnInputsIssue1659) { + auto fusion = std::make_unique(); + auto fusion_ptr = fusion.get(); + FusionGuard fg(fusion_ptr); + + auto tv0 = TensorViewBuilder() + .ndims(3) + .contiguity({true, true, std::nullopt}) + .shape({-1, -1, 1}) + .dtype(DataType::Float) + .build(); + fusion->addInput(tv0); + auto tv1 = TensorViewBuilder() + .ndims(3) + .contiguity({true, std::nullopt, true}) + .shape({-1, 1, -1}) + .dtype(DataType::Float) + .build(); + fusion->addInput(tv1); + auto tv2 = sum(tv0, {1}); + auto tv3 = squeeze(tv1, std::vector{1}); + auto tv4 = add(tv2, tv3); + fusion->addOutput(tv4); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + auto t0 = at::randn({1024, 512, 1}, options); + auto t1 = at::randn({1024, 1, 512}, options); + std::vector aten_inputs({t0, t1}); + + FusionExecutorCache executor_cache(std::move(fusion)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto runtime = executor_cache.getMostRecentKernelRuntime(); + NVF_CHECK(runtime->isSegmented(), "Segmentation expected"); + auto heuristic0 = + runtime->schedulerHeuristics()->heuristicsList().at(0).get()->heuristic(); + NVF_CHECK( + heuristic0 == ScheduleHeuristic::Reduction, + "Unexpected heuristic: ", + heuristic0); + auto heuristic1 = + runtime->schedulerHeuristics()->heuristicsList().at(1).get()->heuristic(); + NVF_CHECK( + heuristic1 == ScheduleHeuristic::Transpose, + "Unexpected heuristic: ", + heuristic1); + testValidate(fusion_ptr, cg_outputs, {t0, t1}, __LINE__, __FILE__); +} + } // namespace nvfuser diff --git a/test/test_id_model.cpp b/test/test_id_model.cpp index fc823316564..d101ca6cabe 100644 --- a/test/test_id_model.cpp +++ b/test/test_id_model.cpp @@ -16,6 +16,7 @@ #include #include #include +#include namespace nvfuser { @@ -37,4 +38,350 @@ TEST_F(IdModelTest, DetectSelfMapping) { ::testing::HasSubstr("!hasSelfMapping"))); } +namespace { + +// Get n-th parent expr traversing through the first input of each +// parent +Expr* getParentExpr(Val* val, int n) { + for (int i = 0; i < n - 1; ++i) { + NVF_ERROR(val->definition() != nullptr); + val = val->definition()->input(0); + } + NVF_ERROR(val->definition() != nullptr); + return 
val->definition(); +}; + +TensorView* getTensorByName( + const std::vector& tvs, + StmtNameType name) { + if (auto it = std::find_if( + tvs.begin(), + tvs.end(), + [&](TensorView* tv) { return tv->name() == name; }); + it != tvs.end()) { + return *it; + } else { + return nullptr; + } +} + +// Create a fusion where we're missing a valid concrete id so the compute at map +// processing will fail. We need to be able to create the concrete ID not just +// look for one. It is not yet possible to lower this fusion as the +// current indexing cannot generate correct indices. Also used in +// FusionIndexing19 as well as Example 2 in the design doc about Loop +// Promotion Analysis. +std::unique_ptr createFusionWithMultipleResolutionPaths() { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({7}); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + + auto tv2 = broadcast(tv1, {false, true}); + + auto tv3 = makeConcreteTensor({7, 11}); + fusion.addInput(tv3); + + auto tv4 = add(tv3, tv2); + auto tv5 = broadcast(tv4, {false, false, true}); + // tv4[7, 11, 1] + + auto tv6 = broadcast(tv1, {false, true}); + + auto tv7 = makeConcreteTensor({7, 13}); + fusion.addInput(tv7); + auto tv8 = add(tv7, tv6); + auto tv9 = broadcast(tv8, {false, true, false}); + // tv9[7, 1, 13] + + auto tv10 = add(tv5, tv9); + fusion.addOutput(tv10); + + // tv10[7, 11, 13] + tv10->merge(0)->merge(0); + // tv10[7*11*13] + tv10->split(0, 5)->split(0, 3); + // tv10[7*11*13//5//3, 3, 5] + + TransformPropagatorWithCheck propagator(tv10); + MaxRootDomainInfoSpanningTree(tv10).traverse(&propagator); + + std::vector tensors_to_inline{tv1, tv2, tv4, tv6, tv8}; + for (auto tensor : tensors_to_inline) { + tensor->inlineAt(1); + } + + return fusion_ptr; +} + +// Check the results of ValGraphStmtSort. Only the ordering of +// ExprGroups is checked for now as it's likely sufficient. +// +// ref_order: The order must be exactly the +// same as indicated by this list. While there can be different +// orders that still satisfy the topological ordering, we also need +// deterministic ordering, so the results should always be the same. +void checkSortingResults( + const ValGraph& graph, + const ExprGroups& sorted_expr_groups, + const ValGroups& sorted_val_groups, + const std::vector& ref_order) { + // Make sure sorted_expr_groups covers all Expr groups + const std::unordered_set& ref_expr_group_set{ + graph.disjointExprSets().disjointSets().begin(), + graph.disjointExprSets().disjointSets().end()}; + std::unordered_set sorted_expr_group_set{ + sorted_expr_groups.begin(), sorted_expr_groups.end()}; + ASSERT_EQ(sorted_expr_group_set, ref_expr_group_set) + << "Mismatched ExprGroups."; + + // Make sure sorted_val_groups covers all Val groups + const std::unordered_set& ref_val_group_set{ + graph.disjointValSets().disjointSets().begin(), + graph.disjointValSets().disjointSets().end()}; + std::unordered_set sorted_val_group_set{ + sorted_val_groups.begin(), sorted_val_groups.end()}; + ASSERT_EQ(sorted_val_group_set, ref_val_group_set) << "Mismatched ValGroups."; + + // Check the ordering + ASSERT_EQ(sorted_expr_groups.size(), ref_order.size()); + for (const auto i : c10::irange(ref_order.size())) { + Expr* ref_expr = ref_order.at(i); + const ExprGroup& eg = sorted_expr_groups.at(i); + ASSERT_TRUE(eg->has(ref_expr)) + << "Mismatch detected at " << i << "-th expr group. "
+ << "Expected: " << nvfuser::toString(graph.toGroup(ref_expr)) << ", " + << ref_expr->toString() << ". Actual: " << nvfuser::toString(eg) << ", " + << eg->front()->toString(); + } +} + +} // namespace + +// Sorting test with a trivial fusion +TEST_F(IdModelTest, ValGraphStmtSort1) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + auto tv2 = add(tv0, tv1); + fusion.addOutput(tv2); + + // No ID expr yet. checkSortingResults validates the expression + // order, but since there's no expr, it just makes sure exprs() and + // vals() return all the val and expr groups. + { + IdModel id_model(&fusion); + const ValGraph& vg = id_model.idGraph(IdMappingMode::EXACT); + ValGraphStmtSort vg_stmt_sort(vg); + checkSortingResults(vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), {}); + } + + // Add ID exprs. Just apply a merge-and-split pattern to all + // tensors. + tv2->merge(0)->split(0, 4); + TransformPropagator propagator(tv2); + MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + + // The exact graph should just map all IDs of the tensors. The + // ordering of the exprs should be the merge and then the split. + { + IdModel id_model(&fusion); + + const ValGraph& vg = id_model.idGraph(IdMappingMode::EXACT); + ValGraphStmtSort vg_stmt_sort(vg); + + // Reference expr order: merge, split + std::vector ref_order; + ref_order.push_back(getParentExpr(tv2->axis(0), 2)); + ref_order.push_back(getParentExpr(tv2->axis(0), 1)); + + checkSortingResults( + vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_order); + } +} + +// Sorting test with a disconnected graph +TEST_F(IdModelTest, ValGraphStmtSort2) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = set(tv0); + fusion.addOutput(tv1); + + auto tv2 = makeSymbolicTensor(2); + fusion.addInput(tv2); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + // Note that the two groups of tensors, {tv0, tv1} and {tv2, tv3}, + // are not connected + + for (auto tv : ir_utils::allTvs(&fusion)) { + tv->merge(0)->split(0, 4); + } + + // Since the two tensors are disconnected, there's no ordering + // between the ID exprs of the two tensor groups. So, the correct + // ordering should have the merge exprs before the split exprs, but + // there's no order between the tv1 and tv3 exprs. For example, + // these are all valid: + // + // tv1 merge -> tv3 merge -> tv1 split -> tv3 split + // tv1 merge -> tv1 split -> tv3 merge -> tv3 split + // tv3 merge -> tv3 split -> tv1 merge -> tv1 split + // tv3 merge -> tv1 merge -> tv3 split -> tv1 split + // + // Here, the actual order returned by ValGraphStmtSort is the first + // one. Since it should be deterministic, we check if the returned + // expr vector is indeed ordered that way. + + IdModel id_model(&fusion); + + const ValGraph& vg = id_model.idGraph(IdMappingMode::EXACT); + ValGraphStmtSort vg_stmt_sort(vg); + + std::vector ref_order; + ref_order.push_back(getParentExpr(tv1->axis(0), 2)); + ref_order.push_back(getParentExpr(tv3->axis(0), 2)); + ref_order.push_back(getParentExpr(tv1->axis(0), 1)); + ref_order.push_back(getParentExpr(tv3->axis(0), 1)); + + checkSortingResults(vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_order); +} + +// Sorting with trivial ExprGroup, i.e., ExprGroup whose input and +// output are mapped as the same ValGroup. It's effectively a cyclic +// dependency and the graph is no longer a DAG.
+TEST_F(IdModelTest, ValGraphStmtSort3) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + auto tv2 = add(tv0, tv1); + fusion.addOutput(tv2); + + auto tv3 = makeSymbolicTensor(2); + fusion.addInput(tv3); + auto tv4 = set(tv3); + fusion.addOutput(tv4); + + // Merge and split by one. The split input and output will be mapped. + for (auto tv : {tv0, tv1, tv2}) { + tv->merge(0)->split(0, 1); + } + + // Also test an isolated trivial expr. Note that tv3 and tv4 are not + // connected with tv0, tv1 and tv2. + tv4->merge(0)->split(0, 1); + + IdModel id_model(&fusion); + ValGraph vg = id_model.idGraph(IdMappingMode::EXACT); + + // Map the split-by-1 input and output + vg.mapVals(tv2->axis(0), tv2->axis(0)->definition()->input(0)); + vg.mapVals(tv4->axis(0), tv4->axis(0)->definition()->input(0)); + + ValGraphStmtSort vg_stmt_sort(vg); + + std::vector ref_order; + ref_order.push_back(getParentExpr(tv2->axis(0), 2)); + ref_order.push_back(getParentExpr(tv4->axis(0), 2)); + ref_order.push_back(getParentExpr(tv2->axis(0), 1)); + ref_order.push_back(getParentExpr(tv4->axis(0), 1)); + + checkSortingResults(vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_order); +} + +// Sorting test with the same fusion as Indexing19 +TEST_F(IdModelTest, ValGraphStmtSort4) { + auto fusion = createFusionWithMultipleResolutionPaths(); + FusionGuard fg(fusion.get()); + auto all_tvs = ir_utils::allTvs(fusion.get()); + + // Since this fusion is not supported by ComputeAtMap, the + // validation flag must be false + IdModel id_model(fusion.get(), false, false, false); + id_model.buildExactGraph(); + const ValGraph& vg = id_model.idGraph(IdMappingMode::EXACT); + + ValGraphStmtSort vg_stmt_sort(vg); + + auto tv1 = getTensorByName(all_tvs, 1); + auto tv2 = getTensorByName(all_tvs, 2); + auto tv4 = getTensorByName(all_tvs, 4); + auto tv5 = getTensorByName(all_tvs, 5); + auto tv6 = getTensorByName(all_tvs, 6); + auto tv8 = getTensorByName(all_tvs, 8); + auto tv9 = getTensorByName(all_tvs, 9); + auto tv10 = getTensorByName(all_tvs, 10); + + // Expected reference order: + // + // exprg{39}: Merge iS2 bS3 + // exprg{57}: Merge iS11 bS12 + // exprg{17}: Merge iS17 bS18 + // exprg{51 63}: Merge iS15 iS16 + // exprg{69 73}: Split iS1 + // exprg{9 25 33 45}: Merge iS20 iS21 + // exprg{41}: Split iS46 + // exprg{59}: Split iS61 + // exprg{19}: Merge iS29 iS19 + // exprg{53 65}: Split iS56 + // exprg{71 75}: Split iS71 + // exprg{11}: Merge iS23 iS22 + // exprg{27}: Merge iS35 bS10 + // exprg{35 47}: Split iS41 + // exprg{43}: Split iS47 + // exprg{61}: Split iS62 + // exprg{21}: Split iS30 + // exprg{55 67}: Split iS57 + // exprg{13}: Split iS24 + // exprg{29}: Split iS36 + // exprg{37 49}: Split iS42 + // exprg{23}: Split iS31 + // exprg{15}: Split iS25 + // exprg{31}: Split iS37 + + std::vector ref_order; + ref_order.push_back(getParentExpr(tv2->axis(0), 3)); + ref_order.push_back(getParentExpr(tv6->axis(0), 3)); + ref_order.push_back(getParentExpr(tv9->axis(0), 4)); + ref_order.push_back(getParentExpr(tv8->axis(0), 3)); + ref_order.push_back(getParentExpr(tv1->axis(0), 2)); + ref_order.push_back(getParentExpr(tv10->axis(0), 4)); + ref_order.push_back(getParentExpr(tv2->axis(0), 2)); + ref_order.push_back(getParentExpr(tv6->axis(0), 2)); + ref_order.push_back(getParentExpr(tv9->axis(0), 3)); + ref_order.push_back(getParentExpr(tv8->axis(0), 2)); + ref_order.push_back(getParentExpr(tv1->axis(0), 1)); + 
ref_order.push_back(getParentExpr(tv10->axis(0), 3)); + ref_order.push_back(getParentExpr(tv5->axis(0), 3)); + ref_order.push_back(getParentExpr(tv4->axis(0), 2)); + ref_order.push_back(getParentExpr(tv2->axis(0), 1)); + ref_order.push_back(getParentExpr(tv6->axis(0), 1)); + ref_order.push_back(getParentExpr(tv9->axis(0), 2)); + ref_order.push_back(getParentExpr(tv8->axis(0), 1)); + ref_order.push_back(getParentExpr(tv10->axis(0), 2)); + ref_order.push_back(getParentExpr(tv5->axis(0), 2)); + ref_order.push_back(getParentExpr(tv4->axis(0), 1)); + ref_order.push_back(getParentExpr(tv9->axis(0), 1)); + ref_order.push_back(getParentExpr(tv10->axis(0), 1)); + ref_order.push_back(getParentExpr(tv5->axis(0), 1)); + + checkSortingResults(vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_order); +} + } // namespace nvfuser diff --git a/test/test_linked_hash_map.cpp b/test/test_linked_hash_map.cpp index 8261a358d7c..83ee8696127 100644 --- a/test/test_linked_hash_map.cpp +++ b/test/test_linked_hash_map.cpp @@ -62,6 +62,7 @@ namespace nvfuser { using testing::ElementsAre; using testing::Eq; using testing::Pair; +using testing::Property; TEST(LinkedHashMapTest, PushBack) { LinkedHashMap map; @@ -118,19 +119,15 @@ TEST(LinkedHashMapTest, EraseThenPushBack) { EXPECT_THAT(map, ElementsAre(Pair("a", 1), Pair("b", 4))); } -namespace { -MATCHER_P(DataIs, data, "") { - return arg.data() == data; -} -} // namespace - TEST(LinkedHashMapTest, MovableValue) { LinkedHashMap map; map.pushBack(CopyableKey("a"), MovableValue(1)); map.pushBack(CopyableKey("b"), MovableValue(2)); map.erase(CopyableKey("b")); - EXPECT_THAT(map, ElementsAre(Pair(CopyableKey("a"), DataIs(1)))); + EXPECT_THAT( + map, + ElementsAre(Pair(CopyableKey("a"), Property(&MovableValue::data, 1)))); } } // namespace nvfuser diff --git a/test/test_matmul_sass.cpp b/test/test_matmul_sass.cpp index 27a52288394..c0d21e9b84f 100644 --- a/test/test_matmul_sass.cpp +++ b/test/test_matmul_sass.cpp @@ -78,7 +78,7 @@ sass::Container getSASSFor( params.promote_prologue_smem_reuse = promote_prologue_smem_reuse; scheduleMatmul(&fusion, params); - auto inputs = matmulAtInput(M, N, K, layout); + auto inputs = matmulAtInput2D(M, N, K, layout); FusionExecutor fe; fe.compileFusion( @@ -133,7 +133,7 @@ sass::Container getBinaryOpMulEpilogueSASSFor( scheduleMatmul(&fusion, params); at::manual_seed(0); - auto inputs = matmulAtInput(M, N, K, layout); + auto inputs = matmulAtInput2D(M, N, K, layout); const double alpha = 2.5; FusionExecutor fe; diff --git a/test/test_matmul_scheduler.cpp b/test/test_matmul_scheduler.cpp index f9c8a5a065f..e8b09092fa6 100644 --- a/test/test_matmul_scheduler.cpp +++ b/test/test_matmul_scheduler.cpp @@ -146,9 +146,10 @@ TEST_P(PrecisionParametrizedTest, EpilogueBias) { const int M = 504, N = 136, K = 248; at::manual_seed(0); - auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at_in_type, M, N, K); - auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at_in_type, M, N, K); - auto t2 = matmulAtInput(layout, TensorMatmulPos::Bias, at_out_type, M, N, K); + auto t0 = matmulAtInput2D(layout, TensorMatmulPos::A, at_in_type, M, N, K); + auto t1 = matmulAtInput2D(layout, TensorMatmulPos::B, at_in_type, M, N, K); + auto t2 = + matmulAtInput2D(layout, TensorMatmulPos::Bias, at_out_type, M, N, K); auto t3 = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); auto t4 = t2.to(at_accu_type); @@ -241,8 +242,8 @@ TEST_P(PrecisionParametrizedTest, EpilogueRelu) { const int M = 504, N = 136, K = 248; at::manual_seed(0); - auto t0 = matmulAtInput(layout, 
TensorMatmulPos::A, at_in_type, M, N, K); - auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at_in_type, M, N, K); + auto t0 = matmulAtInput2D(layout, TensorMatmulPos::A, at_in_type, M, N, K); + auto t1 = matmulAtInput2D(layout, TensorMatmulPos::B, at_in_type, M, N, K); auto t2 = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); auto t3 = at::relu(t2); auto t4 = t3.to(at_out_type); @@ -344,9 +345,10 @@ TEST_P(PrecisionParametrizedTest, EpilogueBiasRelu) { const int M = 504, N = 136, K = 248; at::manual_seed(0); - auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at_in_type, M, N, K); - auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at_in_type, M, N, K); - auto t2 = matmulAtInput(layout, TensorMatmulPos::Bias, at_out_type, M, N, K); + auto t0 = matmulAtInput2D(layout, TensorMatmulPos::A, at_in_type, M, N, K); + auto t1 = matmulAtInput2D(layout, TensorMatmulPos::B, at_in_type, M, N, K); + auto t2 = + matmulAtInput2D(layout, TensorMatmulPos::Bias, at_out_type, M, N, K); auto t3 = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); auto t4 = t2.to(at_accu_type); @@ -442,8 +444,8 @@ TEST_P(PrecisionParametrizedTest, EpilogueReluAux) { const int M = 504, N = 136, K = 248; at::manual_seed(0); - auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at_in_type, M, N, K); - auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at_in_type, M, N, K); + auto t0 = matmulAtInput2D(layout, TensorMatmulPos::A, at_in_type, M, N, K); + auto t1 = matmulAtInput2D(layout, TensorMatmulPos::B, at_in_type, M, N, K); auto t2 = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); auto t3 = t2.to(at_out_type); auto t4 = at::relu(t2); @@ -553,9 +555,10 @@ TEST_P(PrecisionParametrizedTest, EpilogueBiasReluAux) { const int M = 504, N = 136, K = 248; at::manual_seed(0); - auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at_in_type, M, N, K); - auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at_in_type, M, N, K); - auto t2 = matmulAtInput(layout, TensorMatmulPos::Bias, at_out_type, M, N, K); + auto t0 = matmulAtInput2D(layout, TensorMatmulPos::A, at_in_type, M, N, K); + auto t1 = matmulAtInput2D(layout, TensorMatmulPos::B, at_in_type, M, N, K); + auto t2 = + matmulAtInput2D(layout, TensorMatmulPos::Bias, at_out_type, M, N, K); auto t3 = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); auto t4 = t2.to(at_accu_type); @@ -651,8 +654,8 @@ TEST_P(PrecisionParametrizedTest, EpilogueGelu) { const int M = 504, N = 136, K = 248; at::manual_seed(0); - auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at_in_type, M, N, K); - auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at_in_type, M, N, K); + auto t0 = matmulAtInput2D(layout, TensorMatmulPos::A, at_in_type, M, N, K); + auto t1 = matmulAtInput2D(layout, TensorMatmulPos::B, at_in_type, M, N, K); auto t2 = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); auto t3 = at::gelu(t2); auto t4 = t3.to(at_out_type); @@ -743,8 +746,8 @@ TEST_P(PrecisionParametrizedTest, EpilogueGeluAux) { const int M = 504, N = 136, K = 248; at::manual_seed(0); - auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at_in_type, M, N, K); - auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at_in_type, M, N, K); + auto t0 = matmulAtInput2D(layout, TensorMatmulPos::A, at_in_type, M, N, K); + auto t1 = matmulAtInput2D(layout, TensorMatmulPos::B, at_in_type, M, N, K); auto t2 = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); auto t3 = t2.to(at_out_type); auto t4 = at::gelu(t2); @@ -848,9 +851,10 @@ TEST_P(PrecisionParametrizedTest, EpilogueBiasGelu) { const int 
M = 504, N = 136, K = 248; at::manual_seed(0); - auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at_in_type, M, N, K); - auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at_in_type, M, N, K); - auto t2 = matmulAtInput(layout, TensorMatmulPos::Bias, at_out_type, M, N, K); + auto t0 = matmulAtInput2D(layout, TensorMatmulPos::A, at_in_type, M, N, K); + auto t1 = matmulAtInput2D(layout, TensorMatmulPos::B, at_in_type, M, N, K); + auto t2 = + matmulAtInput2D(layout, TensorMatmulPos::Bias, at_out_type, M, N, K); auto t3 = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); auto t4 = t2.to(at_accu_type); @@ -960,9 +964,10 @@ TEST_P(PrecisionParametrizedTest, EpilogueBiasGeluAux) { const int M = 504, N = 136, K = 248; at::manual_seed(0); - auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at_in_type, M, N, K); - auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at_in_type, M, N, K); - auto t2 = matmulAtInput(layout, TensorMatmulPos::Bias, at_out_type, M, N, K); + auto t0 = matmulAtInput2D(layout, TensorMatmulPos::A, at_in_type, M, N, K); + auto t1 = matmulAtInput2D(layout, TensorMatmulPos::B, at_in_type, M, N, K); + auto t2 = + matmulAtInput2D(layout, TensorMatmulPos::Bias, at_out_type, M, N, K); auto t3 = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); auto t4 = t2.to(at_accu_type); @@ -1038,8 +1043,8 @@ TEST_F(MatmulSchedulerTest, BasicMatmulStrictCheckTT) { toString(fusion_layout.getData()), ")"); - auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at::kHalf, M, N, K); - auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at::kHalf, M, N, K); + auto t0 = matmulAtInput2D(layout, TensorMatmulPos::A, at::kHalf, M, N, K); + auto t1 = matmulAtInput2D(layout, TensorMatmulPos::B, at::kHalf, M, N, K); auto tref = atMatmul(t0, t1, layout); FusionExecutorCache executor_cache(std::move(fusion)); @@ -1106,8 +1111,8 @@ TEST_F(MatmulSchedulerTest, BasicMatmulRelaxedCheck) { toString(fusion_layout.getData()), ")"); - auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at::kHalf, M, N, K); - auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at::kHalf, M, N, K); + auto t0 = matmulAtInput2D(layout, TensorMatmulPos::A, at::kHalf, M, N, K); + auto t1 = matmulAtInput2D(layout, TensorMatmulPos::B, at::kHalf, M, N, K); auto tref = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); FusionExecutorCache executor_cache(std::move(fusion)); @@ -1170,8 +1175,8 @@ TEST_F(MatmulSchedulerTest, BasicMatmulInputShuffledTT) { toString(fusion_layout.getData()), ")"); - auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at::kHalf, M, N, K); - auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at::kHalf, M, N, K); + auto t0 = matmulAtInput2D(layout, TensorMatmulPos::A, at::kHalf, M, N, K); + auto t1 = matmulAtInput2D(layout, TensorMatmulPos::B, at::kHalf, M, N, K); auto tref = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); FusionExecutorCache executor_cache(std::move(fusion)); @@ -1238,8 +1243,8 @@ TEST_F(MatmulSchedulerTest, EpilogueOutputCast) { const int M = 504, N = 136, K = 1024; at::manual_seed(0); - auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at::kHalf, M, N, K); - auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at::kHalf, M, N, K); + auto t0 = matmulAtInput2D(layout, TensorMatmulPos::A, at::kHalf, M, N, K); + auto t1 = matmulAtInput2D(layout, TensorMatmulPos::B, at::kHalf, M, N, K); auto t2 = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); auto tref = t2.to(at::kHalf); @@ -1302,8 +1307,8 @@ TEST_F(MatmulSchedulerTest, EpilogueAlpha) { at::manual_seed(0); const 
double alpha = 2.5; - auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at::kHalf, M, N, K); - auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at::kHalf, M, N, K); + auto t0 = matmulAtInput2D(layout, TensorMatmulPos::A, at::kHalf, M, N, K); + auto t1 = matmulAtInput2D(layout, TensorMatmulPos::B, at::kHalf, M, N, K); auto t2 = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); auto tref = at::mul(t2, alpha).to(at::kFloat); @@ -1367,8 +1372,8 @@ TEST_F(MatmulSchedulerTest, EpilogueAlphaOutputCast) { at::manual_seed(0); const double alpha = 2.5; - auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at::kHalf, M, N, K); - auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at::kHalf, M, N, K); + auto t0 = matmulAtInput2D(layout, TensorMatmulPos::A, at::kHalf, M, N, K); + auto t1 = matmulAtInput2D(layout, TensorMatmulPos::B, at::kHalf, M, N, K); auto t2 = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); auto t3 = at::mul(t2, alpha).to(at::kFloat); auto tref = t3.to(at::kHalf); @@ -1441,9 +1446,9 @@ TEST_F(MatmulSchedulerTest, EpilogueBeta) { at::manual_seed(0); const double beta = 2.5; - auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at::kHalf, M, N, K); - auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at::kHalf, M, N, K); - auto t2 = matmulAtInput(layout, TensorMatmulPos::C, at::kHalf, M, N, K); + auto t0 = matmulAtInput2D(layout, TensorMatmulPos::A, at::kHalf, M, N, K); + auto t1 = matmulAtInput2D(layout, TensorMatmulPos::B, at::kHalf, M, N, K); + auto t2 = matmulAtInput2D(layout, TensorMatmulPos::C, at::kHalf, M, N, K); auto t3 = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); @@ -1524,9 +1529,9 @@ TEST_F(MatmulSchedulerTest, EpilogueAlphaBeta) { at::manual_seed(0); const double alpha = 2.5; const double beta = 1.5; - auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at::kHalf, M, N, K); - auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at::kHalf, M, N, K); - auto t2 = matmulAtInput(layout, TensorMatmulPos::C, at::kHalf, M, N, K); + auto t0 = matmulAtInput2D(layout, TensorMatmulPos::A, at::kHalf, M, N, K); + auto t1 = matmulAtInput2D(layout, TensorMatmulPos::B, at::kHalf, M, N, K); + auto t2 = matmulAtInput2D(layout, TensorMatmulPos::C, at::kHalf, M, N, K); auto t3 = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); auto t4 = at::mul(t3, alpha).to(at::kFloat); @@ -1612,9 +1617,9 @@ TEST_F(MatmulSchedulerTest, EpilogueAlphaBetaGeluOutputCast) { at::manual_seed(0); const double alpha = 2.5; const double beta = 1.5; - auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at::kHalf, M, N, K); - auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at::kHalf, M, N, K); - auto t2 = matmulAtInput(layout, TensorMatmulPos::C, at::kHalf, M, N, K); + auto t0 = matmulAtInput2D(layout, TensorMatmulPos::A, at::kHalf, M, N, K); + auto t1 = matmulAtInput2D(layout, TensorMatmulPos::B, at::kHalf, M, N, K); + auto t2 = matmulAtInput2D(layout, TensorMatmulPos::C, at::kHalf, M, N, K); auto t3 = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); auto t4 = at::mul(t3, alpha).to(at::kFloat); @@ -1704,10 +1709,10 @@ TEST_F(MatmulSchedulerTest, EpilogueAlphaBetaBias) { at::manual_seed(0); const double alpha = 2.5; const double beta = 1.5; - auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at::kHalf, M, N, K); - auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at::kHalf, M, N, K); - auto t2 = matmulAtInput(layout, TensorMatmulPos::C, at::kHalf, M, N, K); - auto t3 = matmulAtInput(layout, TensorMatmulPos::Bias, at::kFloat, M, N, K); + auto t0 = 
matmulAtInput2D(layout, TensorMatmulPos::A, at::kHalf, M, N, K); + auto t1 = matmulAtInput2D(layout, TensorMatmulPos::B, at::kHalf, M, N, K); + auto t2 = matmulAtInput2D(layout, TensorMatmulPos::C, at::kHalf, M, N, K); + auto t3 = matmulAtInput2D(layout, TensorMatmulPos::Bias, at::kFloat, M, N, K); auto t4 = atMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); // t5 := (A x B) + bias @@ -1783,8 +1788,10 @@ TEST_F(MatmulSchedulerTest, StridedBatch) { FusionExecutorCache executor_cache(std::move(fusion)); at::manual_seed(0); - auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at::kHalf, M, N, K, B); - auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at::kHalf, M, N, K, B); + auto t0 = + matmulAtInput2D(layout, TensorMatmulPos::A, at::kHalf, M, N, K, B); + auto t1 = + matmulAtInput2D(layout, TensorMatmulPos::B, at::kHalf, M, N, K, B); auto t2 = splitkLikeAtMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); auto outputs = executor_cache.runFusionWithInputs({t0, t1}); @@ -1870,9 +1877,12 @@ TEST_F(MatmulSchedulerTest, StridedBatchEpilogueAlphaBeta) { const double alpha = 2.5; const double beta = 1.5; - auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at::kHalf, M, N, K, B); - auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at::kHalf, M, N, K, B); - auto t2 = matmulAtInput(layout, TensorMatmulPos::C, at::kFloat, M, N, K, B); + auto t0 = + matmulAtInput2D(layout, TensorMatmulPos::A, at::kHalf, M, N, K, B); + auto t1 = + matmulAtInput2D(layout, TensorMatmulPos::B, at::kHalf, M, N, K, B); + auto t2 = + matmulAtInput2D(layout, TensorMatmulPos::C, at::kFloat, M, N, K, B); auto t3 = splitkLikeAtMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); auto t4 = at::mul(t3, alpha).to(at::kFloat); @@ -1966,9 +1976,11 @@ TEST_F(MatmulSchedulerTest, StridedBatchEpilogueAlphaSingleBeta) { const double alpha = 1.5; const double beta = 2.5; - auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at::kHalf, M, N, K, B); - auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at::kHalf, M, N, K, B); - auto t2 = matmulAtInput(layout, TensorMatmulPos::C, at::kFloat, M, N, K); + auto t0 = + matmulAtInput2D(layout, TensorMatmulPos::A, at::kHalf, M, N, K, B); + auto t1 = + matmulAtInput2D(layout, TensorMatmulPos::B, at::kHalf, M, N, K, B); + auto t2 = matmulAtInput2D(layout, TensorMatmulPos::C, at::kFloat, M, N, K); auto t3 = splitkLikeAtMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); auto t4 = at::mul(t3, alpha).to(at::kFloat); @@ -2048,10 +2060,12 @@ TEST_F(MatmulSchedulerTest, StridedBatchEpilogueBias) { FusionExecutorCache executor_cache(std::move(fusion)); at::manual_seed(0); - auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at::kHalf, M, N, K, B); - auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at::kHalf, M, N, K, B); + auto t0 = + matmulAtInput2D(layout, TensorMatmulPos::A, at::kHalf, M, N, K, B); + auto t1 = + matmulAtInput2D(layout, TensorMatmulPos::B, at::kHalf, M, N, K, B); auto t2 = - matmulAtInput(layout, TensorMatmulPos::Bias, at::kFloat, M, N, K, B); + matmulAtInput2D(layout, TensorMatmulPos::Bias, at::kFloat, M, N, K, B); auto t3 = splitkLikeAtMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); auto t4 = atBiasEpilogue(t3, t2).to(at::kFloat); @@ -2126,11 +2140,13 @@ TEST_F(MatmulSchedulerTest, StridedBatchEpilogueSingleBias) { FusionExecutorCache executor_cache(std::move(fusion)); at::manual_seed(0); - auto t0 = matmulAtInput(layout, TensorMatmulPos::A, at::kHalf, M, N, K, B); - auto t1 = matmulAtInput(layout, TensorMatmulPos::B, at::kHalf, M, N, K, B); + auto t0 = + 
matmulAtInput2D(layout, TensorMatmulPos::A, at::kHalf, M, N, K, B); + auto t1 = + matmulAtInput2D(layout, TensorMatmulPos::B, at::kHalf, M, N, K, B); // Explicitly make bias tensor a single dim by passing 0 for batch auto t2 = - matmulAtInput(layout, TensorMatmulPos::Bias, at::kFloat, M, N, K, 0); + matmulAtInput2D(layout, TensorMatmulPos::Bias, at::kFloat, M, N, K, 0); auto t3 = splitkLikeAtMatmul(t0.to(at::kFloat), t1.to(at::kFloat), layout); auto t4 = atBiasEpilogue(t3, t2).to(at::kFloat); diff --git a/test/test_mma.cpp b/test/test_mma.cpp index 361eaac9a7d..d177199ced7 100644 --- a/test/test_mma.cpp +++ b/test/test_mma.cpp @@ -62,18 +62,16 @@ void setAsARange(at::Tensor tensor) { } // namespace debugging -using MmaTestParams = std::tuple; +using MmaTestParams = std::tuple; class MmaTest : public NVFuserFixtureParamTest { protected: - MmaLayout layout; MmaMacro macro; PrimDataType dtype; void SetUp() override { macro = std::get<0>(GetParam()); dtype = std::get<1>(GetParam()); - layout = std::get<2>(GetParam()); if (isTuring(macro) && cudaArchGuardShouldSkip(7, 5)) { GTEST_SKIP() << "skipping tests on pre-Turing GPUs"; @@ -91,40 +89,23 @@ TEST_P(MmaTest, SingleTile) { Fusion fusion; FusionGuard fg(&fusion); - bool transpose_a = (layout == MmaLayout::NT || layout == MmaLayout::NN); - bool transpose_b = (layout == MmaLayout::TT || layout == MmaLayout::NT); - - std::vector A_shape{getM(macro), getK(macro)}, - B_shape{getN(macro), getK(macro)}; - - if (transpose_a) { - std::swap(A_shape[0], A_shape[1]); - } + auto shapes = matmulAtInputShape3DTuring( + getM(macro), getN(macro), getK(macro), MmaLayout::TN); - if (transpose_b) { - std::swap(B_shape[0], B_shape[1]); - } - - auto tv0 = makeConcreteTensor(A_shape, dtype); - auto tv1 = makeConcreteTensor(B_shape, dtype); + auto tv0 = makeConcreteTensor(shapes.first, dtype); + auto tv1 = makeConcreteTensor(shapes.second, dtype); fusion.addInput(tv0); fusion.addInput(tv1); - // [M, K] - if (transpose_a) { - tv0 = transpose(tv0, 0, 1); - } + // [M, 1, K] + // Just doing a gmem->register copy + tv0 = set(tv0); - // [N, K] - if (transpose_b) { - tv1 = transpose(tv1, 0, 1); - } - - // [M, N, K] - auto tv0b = broadcast(tv0, {false, true, false}); - auto tv1b = broadcast(tv1, {true, false, false}); + // [1, N, K] + // Just doing a gmem->register copy + tv1 = set(tv1); - auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); + auto tv2 = fusedMultiplySum(tv0, tv1, {2}); fusion.addOutput(tv2); @@ -138,43 +119,44 @@ TEST_P(MmaTest, SingleTile) { auto tv2c = tv2->cacheBefore(); // [M, N, K] -> [N, M, K] - tv0b->reorder({{-2, -3}, {-3, -2}}); - tv0b->applyMmaSwizzle(MmaOperand::A); - tv1b->applyMmaSwizzle(MmaOperand::B); + tv0->reorder({{-2, -3}, {-3, -2}}); + tv0->applyMmaSwizzle(MmaOperand::A); + tv1->applyMmaSwizzle(MmaOperand::B); - tv0b->merge(1); - tv0b->merge(1); - tv0b->axis(1)->parallelize(ParallelType::TIDx); - tv1b->merge(1); - tv1b->axis(1)->parallelize(ParallelType::TIDx); + tv0->merge(1); + tv0->merge(1); + tv0->axis(1)->parallelize(ParallelType::TIDx); + tv1->merge(1); + tv1->axis(1)->parallelize(ParallelType::TIDx); tv2c->applyMmaSwizzle(MmaOperand::Accumulator); tv2->applyMmaSwizzle(MmaOperand::Accumulator); - auto inputs = matmulAtInput( - getM(macro), getN(macro), getK(macro), layout, data_type_to_aten(dtype)); + auto inputs = matmulAtInput3DTuring( + getM(macro), + getN(macro), + getK(macro), + MmaLayout::TN, + data_type_to_aten(dtype)); FusionExecutor fe; fe.compileFusion( &fusion, {inputs.first, inputs.second}, LaunchParams(), matmul_cparams); 
auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); auto tref = atMatmul( - inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); + inputs.first.squeeze().to(at::kFloat), + inputs.second.squeeze().to(at::kFloat), + MmaLayout::TN); EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5)); } -auto all_mma_layouts = - testing::Values(MmaLayout::TT, MmaLayout::TN, MmaLayout::NT, MmaLayout::NN); - auto all_dtypes = testing::Values(DataType::Half, DataType::BFloat16); std::string testName(const testing::TestParamInfo& info) { std::ostringstream os; auto macro = std::get<0>(info.param); auto dtype = std::get<1>(info.param); - auto layout = std::get<2>(info.param); - os << getM(macro) << "_" << getN(macro) << "_" << getK(macro) << "_" - << toString(layout) << dtype; + os << toString(macro) << dtype; return os.str(); } @@ -186,8 +168,7 @@ INSTANTIATE_TEST_SUITE_P( MmaMacro::Turing_16_8_8, MmaMacro::Turing_16_8_16, MmaMacro::Turing_16_16_16), - testing::Values(DataType::Half), - all_mma_layouts), + testing::Values(DataType::Half)), testName); INSTANTIATE_TEST_SUITE_P( @@ -195,8 +176,7 @@ INSTANTIATE_TEST_SUITE_P( MmaTest, testing::Combine( testing::Values(MmaMacro::Ampere_16_8_16, MmaMacro::Ampere_16_16_16), - all_dtypes, - all_mma_layouts), + all_dtypes), testName); class HopperBase : public NVFuserTest { @@ -219,6 +199,9 @@ void naivelyParallelize(TensorView* tv) { tv->axis(1)->parallelize(ParallelType::TIDx); } +auto all_mma_layouts = + testing::Values(MmaLayout::TT, MmaLayout::TN, MmaLayout::NT, MmaLayout::NN); + auto all_hopper_macros = testing::Values( MmaMacro::Hopper_64_8_16, MmaMacro::Hopper_64_16_16, @@ -284,16 +267,11 @@ TEST_P(HopperRS, SingleTile) { Fusion fusion; FusionGuard fg(&fusion); - bool transpose_a = (layout == MmaLayout::NT || layout == MmaLayout::NN); bool transpose_b = (layout == MmaLayout::TN || layout == MmaLayout::NN); std::vector A_shape{getM(macro), getK(macro)}, B_shape{getK(macro), getN(macro)}; - if (transpose_a) { - std::swap(A_shape[0], A_shape[1]); - } - if (transpose_b) { std::swap(B_shape[0], B_shape[1]); } @@ -303,11 +281,6 @@ TEST_P(HopperRS, SingleTile) { fusion.addInput(tv0); fusion.addInput(tv1); - // [M, K] - if (transpose_a) { - tv0 = transpose(tv0, 0, 1); - } - TensorView* tv0b = nullptr; int axes = 0; if (transpose_b) { @@ -360,7 +333,7 @@ TEST_P(HopperRS, SingleTile) { tv2c->applyMmaSwizzle(MmaOperand::Accumulator); tv2->applyMmaSwizzle(MmaOperand::Accumulator); - auto inputs = matmulAtInput( + auto inputs = matmulAtInput2D( getM(macro), getN(macro), getK(macro), layout, data_type_to_aten(dtype)); FusionExecutor fe; @@ -380,8 +353,8 @@ std::string testNameHopperRS( auto dtype = std::get<1>(info.param); auto layout = std::get<2>(info.param); auto swizzle_b = std::get<3>(info.param); - os << getM(macro) << "_" << getN(macro) << "_" << getK(macro) << "_" - << toString(layout) << "_" << toString(swizzle_b) << dtype; + os << toString(macro) << "_" << toString(layout) << "_" << toString(swizzle_b) + << dtype; return os.str(); } @@ -391,7 +364,7 @@ INSTANTIATE_TEST_SUITE_P( testing::Combine( all_hopper_macros, all_dtypes, - all_mma_layouts, + testing::Values(MmaLayout::TT, MmaLayout::TN), all_smem_swizzle_modes), testNameHopperRS); @@ -528,7 +501,7 @@ TEST_P(HopperSS, SingleTile) { tv2c->applyMmaSwizzle(MmaOperand::Accumulator); tv2->applyMmaSwizzle(MmaOperand::Accumulator); - auto inputs = matmulAtInput( + auto inputs = matmulAtInput2D( getM(macro), getN(macro), getK(macro), layout, data_type_to_aten(dtype)); FusionExecutor fe; 
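A quick standalone sketch (not part of the patch, ATen only) of the shape convention the reworked MmaTest above relies on: for the TN layout, matmulAtInput3DTuring is assumed to hand back A as [M, 1, K] and B as [1, N, K], so the broadcasted multiply-and-sum over K that the fusion performs is the same computation as a plain matmul on the squeezed 2-D operands. M, N, K below are arbitrary illustration values, and float is used instead of half so the allclose comparison stays tight.

#include <ATen/ATen.h>

int main() {
  const int64_t M = 16, N = 8, K = 16;
  const auto opts = at::TensorOptions().dtype(at::kFloat);
  // TN-layout operands in the assumed 3-D form: A is [M, 1, K], B is [1, N, K].
  at::Tensor a = at::randn({M, 1, K}, opts);
  at::Tensor b = at::randn({1, N, K}, opts);
  // What fusedMultiplySum(tv0, tv1, {2}) expresses: broadcast to [M, N, K], reduce over K.
  at::Tensor fused = (a * b).sum({2});
  // What the test compares against: squeeze the broadcast axes and run a TN matmul (A x B^T).
  at::Tensor ref = at::matmul(a.squeeze(1), b.squeeze(0).t());
  TORCH_CHECK(at::allclose(fused, ref, 1e-5, 1e-5));
  return 0;
}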
diff --git a/test/test_multidevice_communications.cpp b/test/test_multidevice_communications.cpp index 4bdc63dcc3f..249b5df3f49 100644 --- a/test/test_multidevice_communications.cpp +++ b/test/test_multidevice_communications.cpp @@ -5,7 +5,7 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#ifdef USE_DISTRIBUTED +#ifdef NVFUSER_DISTRIBUTED #include #include diff --git a/test/test_multidevice_pipeline.cpp b/test/test_multidevice_pipeline.cpp index 5c16a4c42c2..cfe600dd417 100644 --- a/test/test_multidevice_pipeline.cpp +++ b/test/test_multidevice_pipeline.cpp @@ -5,7 +5,7 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#ifdef USE_DISTRIBUTED +#ifdef NVFUSER_DISTRIBUTED #include #include @@ -45,7 +45,7 @@ using namespace torch::jit::fuser::cuda; using namespace at::indexing; /* To run the following tests on several devices, pytorch must be installed - with the flag USE_DISTRIBUTED=1 and nccl support. + with the flag NVFUSER_DISTRIBUTED=1 and nccl support. Then simply run the tests on several processes, for example using mpirun on a node having at least 6 GPUs, e.g.: mpirun -np 6 build/nvfuser_tests diff --git a/test/test_multidevice_sharding.cpp b/test/test_multidevice_sharding.cpp index d8a4e35f9a1..a87786454b2 100644 --- a/test/test_multidevice_sharding.cpp +++ b/test/test_multidevice_sharding.cpp @@ -5,7 +5,7 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#ifdef USE_DISTRIBUTED +#ifdef NVFUSER_DISTRIBUTED #include #include #include diff --git a/test/test_pipeline.cpp b/test/test_pipeline.cpp index 9cc6e6dcf3a..77bcddfe0f2 100644 --- a/test/test_pipeline.cpp +++ b/test/test_pipeline.cpp @@ -208,9 +208,9 @@ class automaticReshardingTest .only_segment_resharding_exprs = true}; auto segmented_fusion = - SegmentCandidateFinder::segment(std::move(fusion), options); + SegmentCandidateFinder::segment(std::move(fusion), nullptr, options); - for (auto group : segmented_fusion->groups()) { + for (SegmentedGroup* group : segmented_fusion->groups()) { GTEST_EXPECT_TRUE( std::none_of( group->exprs().begin(), diff --git a/test/test_pointwise.cpp b/test/test_pointwise.cpp index 7fd5ec805f8..aeb8a0c2650 100644 --- a/test/test_pointwise.cpp +++ b/test/test_pointwise.cpp @@ -4,15 +4,16 @@ * All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause */ +// clang-format on #include #include -#include -#include #include +#include +#include +#include #include #include -#include namespace nvfuser { @@ -33,6 +34,22 @@ size_t getVecSizeForPointwise(FusionExecutorCache& fec) { return 1; } +bool hasVectorizationCache(TensorView* tv) { + NVF_CHECK(tv->isFusionInput()); + NVF_CHECK(tv->uses().size() == 1); + auto set_expr = dynamic_cast(tv->uses().at(0)); + NVF_CHECK(set_expr != nullptr && set_expr->opType() == LoadStoreOpType::Set); + auto cached_input = set_expr->out()->as(); + NVF_CHECK(cached_input, "expects input to be cached"); + + for (const auto* id : cached_input->getLeafDomain()) { + if (id->getParallelType() == ParallelType::Vectorize) { + return true; + } + } + return false; +} + } // namespace TEST_F(PointwiseTest, VectorizeStrideContiguity2D) { @@ -195,10 +212,219 @@ TEST_F(PointwiseTest, VectorizeAllocationDomain) { fec.profile(true); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::empty_strided({1024, 128, 25}, {128*25, 1, 128}, options); + at::Tensor t0 = + at::empty_strided({1024, 128, 25}, {128 * 25, 1, 128}, options); auto cg_outputs = fec.runFusionWithInputs({t0}); EXPECT_EQ(getVecSizeForPointwise(fec), 4); testValidate(fusion, cg_outputs, {t0}, __LINE__, __FILE__); } +// All inputs & outputs share the same allocation domain permutation from root +// domain, but intermediate tv2 isn't specified a stride order. There's also a +// broadcast IterDomain on tv1, which is tricky for vectorization analysis to +// figure out which axes should be excluded from the computation of +// vectorization factor. +TEST_F(PointwiseTest, Issue1567VectorizeAllocationDomain) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = TensorViewBuilder() + .ndims(3) + .contiguity({true, true, true}) + .strideOrder({2, 0, 1}) + .build(); + TensorView* tv1 = TensorViewBuilder() + .ndims(3) + .shape({1, -1, 1}) + .contiguity({std::nullopt, std::nullopt, true}) + .strideOrder({2, 0, 1}) + .build(); + fusion->addInput(tv0); + fusion->addInput(tv1); + auto tv2 = add(tv0, tv1); + auto tv3 = add(tv2, IrBuilder::create(1.0, DataType::Float)); + tv3->setAllocationDomain({tv3->axis(0), tv3->axis(2), tv3->axis(1)}, true); + fusion->addOutput(tv3); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = + at::empty_strided({1024, 128, 25}, {128 * 25, 1, 128}, options); + at::Tensor t1 = at::empty_strided({1, 128, 1}, {128, 1, 128}, options); + std::vector aten_inputs = {t0, t1}; + + // NOTE: force pointwise scheduler here just for testing purpose + auto params = getPointwiseHeuristics(fusion, aten_inputs); + auto lparams = schedulePointwise(fusion, aten_inputs); + FusionExecutor fe; + fe.compileFusion(fusion, aten_inputs, lparams); + auto cg_outputs = fe.runFusion(aten_inputs, lparams); + + EXPECT_EQ(params->vectorize, true); + EXPECT_EQ(params->unroll_factor, 4); + EXPECT_TRUE(hasVectorizationCache(tv0)); + EXPECT_TRUE(hasVectorizationCache(tv1)); + + testValidate(fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); +} + +TEST_F(PointwiseTest, Issue1567VectorizationFactorAnalysisCase0) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = TensorViewBuilder() + .ndims(3) + .contiguity({true, true, std::nullopt}) + .shape({-1, -1, 1}) + .build(); + TensorView* tv1 = + 
TensorViewBuilder().ndims(3).contiguity({true, true, true}).build(); + fusion->addInput(tv0); + fusion->addInput(tv1); + auto tv2 = add(tv0, tv1); + fusion->addOutput(tv2); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({1024, 2, 1}, options); + at::Tensor t1 = at::randn({1024, 2, 512}, options); + std::vector aten_inputs = {t0, t1}; + + // NOTE: force pointwise scheduler here just for testing purpose + auto params = getPointwiseHeuristics(fusion, aten_inputs); + auto lparams = schedulePointwise(fusion, aten_inputs); + FusionExecutor fe; + fe.compileFusion(fusion, aten_inputs, lparams); + auto cg_outputs = fe.runFusion(aten_inputs, lparams); + + EXPECT_EQ(params->vectorize, true); + EXPECT_EQ(params->unroll_factor, 4); + EXPECT_FALSE(hasVectorizationCache(tv0)); + EXPECT_TRUE(hasVectorizationCache(tv1)); + + testValidate(fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); +} + +TEST_F(PointwiseTest, Issue1567VectorizationFactorAnalysisCase1) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = TensorViewBuilder() + .ndims(3) + .contiguity({true, std::nullopt, true}) + .shape({-1, 1, -1}) + .build(); + TensorView* tv1 = + TensorViewBuilder().ndims(3).contiguity({true, true, true}).build(); + fusion->addInput(tv0); + fusion->addInput(tv1); + auto tv2 = add(tv0, tv1); + fusion->addOutput(tv2); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({1024, 1, 2}, options); + at::Tensor t1 = at::randn({1024, 512, 2}, options); + std::vector aten_inputs = {t0, t1}; + + // NOTE: force pointwise scheduler here just for testing purpose + auto params = getPointwiseHeuristics(fusion, aten_inputs); + auto lparams = schedulePointwise(fusion, aten_inputs); + FusionExecutor fe; + fe.compileFusion(fusion, aten_inputs, lparams); + auto cg_outputs = fe.runFusion(aten_inputs, lparams); + + EXPECT_EQ(params->vectorize, true); + EXPECT_EQ(params->unroll_factor, 2); + EXPECT_TRUE(hasVectorizationCache(tv0)); + EXPECT_TRUE(hasVectorizationCache(tv1)); + + testValidate(fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); +} + +TEST_F(PointwiseTest, Issue1567VectorizationFactorAnalysisCase2) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = TensorViewBuilder() + .ndims(3) + .contiguity({true, std::nullopt, true}) + .shape({-1, 1, -1}) + .build(); + TensorView* tv1 = TensorViewBuilder() + .ndims(3) + .contiguity({true, true, true}) + .strideOrder({1, 2, 0}) + .build(); + fusion->addInput(tv0); + fusion->addInput(tv1); + auto tv2 = add(tv0, tv1); + auto tv3 = transpose(tv2, 0, 1); + fusion->addOutput(tv3); + + FusionExecutorCache fec(std::move(fusion_ptr)); + fec.profile(true); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({1024, 1, 2}, options); + at::Tensor t1 = at::empty_strided({1024, 512, 2}, {2, 2048, 1}, options); + std::vector aten_inputs = {t0, t1}; + + // NOTE: force pointwise scheduler here just for testing purpose + auto params = getPointwiseHeuristics(fusion, aten_inputs); + auto lparams = schedulePointwise(fusion, aten_inputs); + FusionExecutor fe; + fe.compileFusion(fusion, aten_inputs, lparams); + auto cg_outputs = fe.runFusion(aten_inputs, lparams); + + EXPECT_EQ(params->vectorize, true); + EXPECT_EQ(params->unroll_factor, 4); + EXPECT_TRUE(hasVectorizationCache(tv0)); + 
EXPECT_TRUE(hasVectorizationCache(tv1)); + + testValidate(fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); +} + +TEST_F(PointwiseTest, Issue1567VectorizationFactorAnalysisCase3) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = TensorViewBuilder() + .ndims(3) + .contiguity({std::nullopt, true, true}) + .shape({1, -1, -1}) + .build(); + TensorView* tv1 = + TensorViewBuilder().ndims(3).contiguity({true, true, true}).build(); + fusion->addInput(tv0); + fusion->addInput(tv1); + auto tv2 = add(tv0, tv1); + auto tv3 = transpose(tv2, 0, 1); + fusion->addOutput(tv3); + + FusionExecutorCache fec(std::move(fusion_ptr)); + fec.profile(true); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({1, 1024, 2}, options); + at::Tensor t1 = at::randn({512, 1024, 2}, options); + std::vector aten_inputs = {t0, t1}; + + // NOTE: force pointwise scheduler here just for testing purpose + auto params = getPointwiseHeuristics(fusion, aten_inputs); + auto lparams = schedulePointwise(fusion, aten_inputs); + FusionExecutor fe; + fe.compileFusion(fusion, aten_inputs, lparams); + auto cg_outputs = fe.runFusion(aten_inputs, lparams); + + EXPECT_EQ(params->vectorize, true); + EXPECT_EQ(params->unroll_factor, 2); + EXPECT_TRUE(hasVectorizationCache(tv0)); + EXPECT_TRUE(hasVectorizationCache(tv1)); + + testValidate(fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); +} + } // namespace nvfuser diff --git a/test/test_serial_gridreduce.cpp b/test/test_serial_gridreduce.cpp index 43675299b68..861e2270293 100644 --- a/test/test_serial_gridreduce.cpp +++ b/test/test_serial_gridreduce.cpp @@ -35,227 +35,6 @@ namespace nvfuser { using SerialGridReductionTest = NVFuserTest; -// Test that we are able to generate code for a serial reduction -// TODO: remove this test once lowering of serial grid reductions is implemented -TEST_F(SerialGridReductionTest, CodegenNodes) { - for (bool serial : {true, false}) { - for (int64_t num_warps : {4, 8}) { - // B is size of inner serial loop. Outer loop is hardcoded at A=4 - // Here we set B to a small value of 8 instead of 32 (i.e. 128 elements - // per thread), so that the non-serial compilation does not take too - // long. - for (int64_t B : {8}) { - std::unique_ptr fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); - - int64_t blocks_x = 8; - int64_t blocks_y = 8; - int64_t blocks_z = 5; - int64_t A = 4; // Size of outer serial loop - int64_t H = blocks_z; - int64_t W = A * B * blocks_x * blocks_y * num_warps * 32; - - // Unreduced dimensions should be concrete. Reduced dimension could be - // symbolic, but is concrete here so that we can read tv0 to registers - TensorView* tv0 = TensorViewBuilder() - .shape({H, W}) - .dtype(DataType::Float) - .contiguity(true) - .build(); - fusion->addInput(tv0); - - auto tv1 = sum(tv0, {0}); - fusion->addOutput(tv1); - - // Start with - // [ rS{H}, iS{W} ] - // We are grid reducing the H dimension and we want to coalesce - // accesses in the W dimension.
So we first reorder to - // [ iS{W}, rS{H} ] - // then schedule as - // [ iBIDx{blocks_x}, iBIDy{blocks_y}, iS{A}, iS{B}, iTIDy{num_warps}, - // iTIDx{32}, rBIDz{blocks_z} ] - auto tv2 = tv0->cacheAfter(); - auto tv3 = tv1->cacheBefore(); - - tv3->reorder({{1, 0}, {0, 1}}); // blocks_x*blocks_y*A*B*num_warps*32, H - tv3->split(0, 32); // blocks_x*blocks_y*A*B*num_warps, 32, H - tv3->split(0, num_warps); // blocks_x*blocks_y*A*B, num_warps, 32, H - tv3->split(0, B); // blocks_x*blocks_y*A, B, num_warps, 32, H - tv3->split(0, A); // blocks_x*blocks_y, A, B, num_warps, 32, H - tv3->split(0, blocks_y); // blocks_x, blocks_y, A, B, num_warps, 32, H - tv3->axis(0)->parallelize(ParallelType::BIDx); - tv3->axis(1)->parallelize(ParallelType::BIDy); - tv3->axis(4)->parallelize(ParallelType::TIDy); - tv3->axis(5)->parallelize(ParallelType::TIDx); - tv3->axis(6)->parallelize(ParallelType::BIDz); - // Reorder to put parallel dims first for better inlining - tv3->reorder({ - {4, 2}, - {5, 3}, - {2, 4}, - {3, 5}, - }); - - TransformPropagator propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); - scheduler_utils::parallelizeAllLike(tv3); - - // Here we just transpose A and B in tv2, so that it will be partially - // inlined with tv3, resulting in a separate loop to load tv0 into - // registers (tv2). - tv2->reorder({ - {-2, -3}, - {-3, -2}, - }); - - inlineMost(); - - FusionExecutor fe; - if (serial) { - fe.registerPostLoweringHook([](kir::Kernel* kernel) { - FusionGuard fg(kernel); - - std::vector& top_level_exprs = - const_cast&>(kernel->topLevelExprs()); - kir::KernelSummary& summary = - const_cast(kernel->summary()); - std::vector& global_allocations = - summary.global_allocations; - // There should be a work buffer and a sync buffer allocated - ASSERT_EQ(global_allocations.size(), 2); - - // Find the position of the last outer loop - size_t top_level_loop_pos = -1; - for (size_t i : c10::irange(top_level_exprs.size())) { - Expr* expr = top_level_exprs.at(i); - if (expr->isA()) { - top_level_loop_pos = i; - } - } - - // This is a poor approximation of a traversal that would appear in - // a lowering pass to both set the isSerial() flag on grid - // reductions and insert wait/release syncs. - // - // tidx_scope is the inner-most fully parallelized scope. 
It is - // "top-level" in that its loops appear as top-level in the - // generated kernel - kir::Scope& tidx_scope = top_level_exprs.at(top_level_loop_pos) - ->as() - ->body() // BIDx - .at(0) - ->as() - ->body() // BIDy - .at(0) - ->as() - ->body() // TIDy - .at(0) - ->as() - ->body(); // TIDx - kir::Scope& B_scope = tidx_scope.exprs() - .at(5) - ->as() - ->body() // A (reduction loop) - .exprs() - .back() - ->as() - ->body(); // B - // We will need the store op output TensorIndex - LoadStoreOp* output_store_expr = B_scope.exprs() - .back() - ->as() - ->thenBody() - .at(0) - ->as(); - // bidz_scope is the scope containing the GridReduction expression - kir::Scope& bidz_scope = - B_scope.exprs().at(4)->as()->body(); // BIDz - auto old_grop = bidz_scope.at(0)->as(); - // Store the TensorIndex for the output tensor T1_g, so that we can - // re-use its index - auto t1_idx = output_store_expr->output(0)->as(); - - // Create new TensorView and Allocate - auto output = kernel->outputs().at(0)->as(); - Val* i0 = output->getRootDomain().at(0)->extent(); - auto new_work_buf_tv = - TensorViewBuilder().shape(std::vector{i0}).build(); - new_work_buf_tv->setMemoryType(MemoryType::Global); - // associate the index of the output tensor with the work buffer - // NOTE: in actual lowering we would generate an index ourselves - // here, but this works for this test since the T1 store is inlined - // fully with the serial grid reduction. - Val* idx = t1_idx->index(); - - auto new_work_buf_idx = - IrBuilder::create(new_work_buf_tv, idx); - auto new_work_buf_alloc = IrBuilder::create( - new_work_buf_tv, MemoryType::Global, std::vector{i0}); - const kir::Allocate* orig_work_buf_alloc = global_allocations[0]; - global_allocations[0] = new_work_buf_alloc; - // replace work buf alloc expr in top_level_exprs - for (auto i : c10::irange(top_level_exprs.size())) { - if (top_level_exprs[i] == orig_work_buf_alloc) { - top_level_exprs[i] = new_work_buf_alloc; - } - } - // replace work buf in kernel->parameters() - std::vector& params = - const_cast&>(kernel->parameters()); - for (auto i : c10::irange(params.size())) { - if (params[i] == orig_work_buf_alloc->buffer()) { - params[i] = new_work_buf_tv; - } - } - // replace the grid reduction Expr - auto new_grop = IrBuilder::create( - old_grop->getReductionOpType(), - old_grop->init(), - old_grop->out(), - old_grop->in(), - new_work_buf_alloc, - old_grop->sync_buffer(), - old_grop->entrance_index(), - old_grop->entrances(), - old_grop->isAllreduce(), - new_work_buf_idx); - new_grop = new_grop->withPredicate(old_grop->predicate()) - ->as(); - new_grop = new_grop->withWritePredicate(old_grop->writePredicate()) - ->as(); - bidz_scope.at(0) = new_grop; - - auto sync_buf = global_allocations.at(1)->buffer(); - - std::vector& nonpar_top_level_exprs = - const_cast&>(tidx_scope.exprs()); - nonpar_top_level_exprs.insert( - nonpar_top_level_exprs.end() - 2, - IrBuilder::create( - ParallelTypeBitmap(ParallelType::BIDz), sync_buf)); - - nonpar_top_level_exprs.insert( - nonpar_top_level_exprs.end() - 1, - IrBuilder::create( - ParallelTypeBitmap(ParallelType::BIDz), sync_buf)); - }); - } - fe.compileFusion(fusion); - - auto input = at::randn( - {H, W}, at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0)); - auto outputs = fe.runFusion({input}); - - if (serial) { - testValidate(fusion, outputs, {input}, __LINE__, __FILE__); - } - } - } - } -} - TEST_F(SerialGridReductionTest, Scheduling) { for (bool serial : {true, false}) { for (int64_t num_warps : {4, 8}) { diff --git 
a/test/test_tensor_factories.cpp b/test/test_tensor_factories.cpp index 7812287a2f4..76e8c27e566 100644 --- a/test/test_tensor_factories.cpp +++ b/test/test_tensor_factories.cpp @@ -410,4 +410,19 @@ TEST_F(TensorFactoryTest, MetadataAsTensor) { testValidate(fusion.get(), cg_outputs, {input0, input1}, __LINE__, __FILE__); } +TEST_F(TensorFactoryTest, NoInputs) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + Val* size = IrBuilder::create(16); + Val* fill_value = IrBuilder::create(1.0); + TensorView* out = full({size}, fill_value, DataType::Float); + fusion->addOutput(out); + + FusionExecutorCache executor_cache(std::move(fusion)); + + auto out_tensors = executor_cache.runFusionWithInputs({}); + testValidate(executor_cache.fusion(), out_tensors, {}, __LINE__, __FILE__); +} + } // namespace nvfuser diff --git a/test/utils.cpp b/test/utils.cpp index c39f899c5b4..5d10bee07d6 100644 --- a/test/utils.cpp +++ b/test/utils.cpp @@ -257,8 +257,8 @@ Container parse(const std::string& nvdisasm_output) { } // namespace sass -// matmulAtInput provides batched inputs in a splitk-like ordering. It provides -// contiguous tensors with these shapes +// matmulAtInput2D provides batched inputs in a splitk-like ordering. It +// provides contiguous tensors with these shapes // TT: [M, B, K] [B, K, N] // TN: [M, B, K] [N, B, K] // NT: [B, K, M] [B, K, N] @@ -388,8 +388,8 @@ TensorView* splitkLikeBatchedMatmul( return tv2; } -// matmulAtInput provides batched inputs in a splitk-like ordering. It provides -// contiguous tensors with these shapes +// matmulAtInput2D provides batched inputs in a splitk-like ordering. It +// provides contiguous tensors with these shapes // TT: [M, B, K] [B, K, N] // TN: [M, B, K] [N, B, K] // NT: [B, K, M] [B, K, N] @@ -451,7 +451,7 @@ at::Tensor splitkLikeAtMatmul(at::Tensor a, at::Tensor b, MmaLayout layout) { return at::Tensor(); } -std::pair matmulAtInput( +std::pair matmulAtInput2D( int M, int N, int K, @@ -478,7 +478,38 @@ std::pair matmulAtInput( return std::make_pair(at::Tensor(), at::Tensor()); } -at::Tensor matmulAtInput( +std::pair, std::vector> matmulAtInputShape3DTuring( + int M, + int N, + int K, + MmaLayout layout) { + switch (layout) { + case MmaLayout::TT: + return {{M, 1, K}, {1, K, N}}; + case MmaLayout::TN: + return {{M, 1, K}, {1, N, K}}; + case MmaLayout::NT: + return {{K, 1, M}, {1, K, N}}; + case MmaLayout::NN: + return {{K, 1, M}, {1, N, K}}; + default: + NVF_CHECK(false, "unsupported data layout."); + } +} + +std::pair matmulAtInput3DTuring( + int M, + int N, + int K, + MmaLayout layout, + c10::ScalarType dtype) { + auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0); + auto shapes = matmulAtInputShape3DTuring(M, N, K, layout); + return std::make_pair( + at::randn(shapes.first, options), at::randn(shapes.second, options)); +} + +at::Tensor matmulAtInput2D( const MmaLayout layout, const TensorMatmulPos tensor, const c10::ScalarType dtype, diff --git a/test/utils.h b/test/utils.h index 1b60c5d8642..6803a183d45 100644 --- a/test/utils.h +++ b/test/utils.h @@ -388,11 +388,13 @@ inline bool maybeClearAllocator(int64_t max_bytes = ((int64_t)1 << 32)) { auto allocator = c10::cuda::CUDACachingAllocator::get(); if (allocator->initialized()) { int device = 0; -#define TORCH_VERSION_GREATER(major, minor, patch) \ - TORCH_VERSION_MAJOR > major || \ - (TORCH_VERSION_MAJOR == major && TORCH_VERSION_MINOR > minor || \ - (TORCH_VERSION_MINOR == minor && TORCH_VERSION_PATCH > patch)) -#if TORCH_VERSION_GREATER(2, 0, 1) +#if 
NVF_TORCH_VERSION_GREATER(2, 2, 0) + // c10::cuda uses DeviceIndex instead of int + // https://github.com/pytorch/pytorch/pull/119142 + c10::DeviceIndex device_index; + c10::cuda::GetDevice(&device_index); + device = static_cast(device_index); +#elif NVF_TORCH_VERSION_GREATER(2, 0, 1) // GetDevice was introduced in https://github.com/pytorch/pytorch/pull/94864 // in order to properly handle new CUDA 112 behavior c10::cuda::GetDevice(&device); @@ -400,7 +402,7 @@ inline bool maybeClearAllocator(int64_t max_bytes = ((int64_t)1 << 32)) { cudaGetDevice(&device); #endif - auto device_stats = allocator->getDeviceStats(0); + auto device_stats = allocator->getDeviceStats(device); // allocated_bytes[] holds multiple statistics but the first is sum across // both small and large blocks if (uint64_t(device_stats.reserved_bytes[0].current) > @@ -614,7 +616,22 @@ at::Tensor atMatmul(at::Tensor a, at::Tensor b, MmaLayout layout); at::Tensor splitkLikeAtMatmul(at::Tensor a, at::Tensor b, MmaLayout layout); // Utility to generate inputs based on given layout -std::pair matmulAtInput( +std::pair matmulAtInput2D( + int M, + int N, + int K, + MmaLayout layout, + c10::ScalarType dtype = at::kHalf); + +// Utility to generate input shapes based on given layout +std::pair, std::vector> matmulAtInputShape3DTuring( + int M, + int N, + int K, + MmaLayout layout); + +// Utility to generate inputs based on given layout +std::pair matmulAtInput3DTuring( int M, int N, int K, @@ -630,7 +647,7 @@ enum class TensorMatmulPos { A, B, C, D, Bias }; // Utility to generate buffers based on given problem, layout and tensor // position in matmul with support for matmul and strided batch matmul -at::Tensor matmulAtInput( +at::Tensor matmulAtInput2D( const MmaLayout layout, const TensorMatmulPos tensor, const c10::ScalarType dtype, diff --git a/tools/examples/repro.py b/tools/examples/repro.py index c4653e11b5e..1991fedbf96 100644 --- a/tools/examples/repro.py +++ b/tools/examples/repro.py @@ -15,7 +15,7 @@ T14 = fd.ops.mul(T8, T13) T15 = fd.ops.cast(T14, dtype=DataType.Half) T16 = fd.ops.cast(T15, dtype=DataType.Float) -T17, T18 = fd.ops.var_mean(T16, axes=[2], correction=0, keepdim=False) +T17, T18 = fd.ops.var_mean(T16, dims=[2], correction=0, keepdim=False) T19 = fd.ops.broadcast_in_dim(T17, output_shape=[1, 1024, 1], broadcast_dims=[0, 1]) T20 = fd.ops.broadcast_in_dim(T18, output_shape=[1, 1024, 1], broadcast_dims=[0, 1]) S21 = fd.define_scalar(1.00000e-05) diff --git a/tools/gen_nvfuser_version.py b/tools/gen_nvfuser_version.py index 7537ff3ad4a..789aa96d37a 100644 --- a/tools/gen_nvfuser_version.py +++ b/tools/gen_nvfuser_version.py @@ -45,6 +45,22 @@ def get_pytorch_cmake_prefix(): return stdout_msg.decode("utf-8").rstrip("\n") +def get_pytorch_use_distributed(): + from subprocess import Popen, PIPE + + # need to do this in a separate process so we are not going to delete nvfuser library while it's loaded by torch + process_torch_prefix = Popen( + [ + sys.executable, + "-c", + "import torch; print(torch._C._has_distributed())", + ], + stdout=PIPE, + ) + stdout_msg, error_msg = process_torch_prefix.communicate() + return stdout_msg.decode("utf-8").rstrip("\n") + + if __name__ == "__main__": version_file = nvfuser_root / "nvfuser" / "version.py" with open(version_file, "w") as f: diff --git a/version.txt b/version.txt index 9faa1b7a733..c946ee6160c 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.1.5 +0.1.6