diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc index 872a68ac7833a..f023843b818bd 100644 --- a/xla/debug_options_flags.cc +++ b/xla/debug_options_flags.cc @@ -256,6 +256,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_multi_streamed_windowed_einsum(false); + opts.set_xla_gpu_experimental_stream_annotation(false); // Minimum combined size of matrices in matrix multiplication to // be rewritten to cuBLAS or Triton kernel call. // This threshold is a conservative estimate and has been measured diff --git a/xla/service/BUILD b/xla/service/BUILD index af652b65ccc3a..53e840847a067 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -900,6 +900,7 @@ cc_library( deps = [ ":call_graph", ":hlo_domain_isolator", + "//xla:side_effect_util", "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", @@ -934,6 +935,7 @@ xla_cc_test( "//xla/hlo/parser:hlo_parser", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", + "//xla/tests:filecheck", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log", diff --git a/xla/service/call_inliner.cc b/xla/service/call_inliner.cc index 1fb8652110a77..b4818792f30e6 100644 --- a/xla/service/call_inliner.cc +++ b/xla/service/call_inliner.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/string_view.h" +#include "xla/side_effect_util.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -160,6 +161,19 @@ bool InlineComposites( instruction->frontend_attributes().map().at("composite.name")); } +// Stream-annotation gate used by IsInlineableCallOp. Returns true (inlining +// allowed) whenever xla_gpu_experimental_stream_annotation is off. With the +// option on, returns false exactly when `instruction` carries the +// kXlaStreamAnnotationAttr frontend attribute, and true otherwise. +bool InlineStreamAnnotation(HloInstruction* instruction) { + const DebugOptions& debug_opts = instruction->GetModule()->config().debug_options(); + if (!debug_opts.xla_gpu_experimental_stream_annotation()) { + return true; + } + const auto& attrs = instruction->frontend_attributes().map(); + return !attrs.contains(kXlaStreamAnnotationAttr); +} + } // namespace /* static */ absl::StatusOr @@ -213,7 +227,8 @@ bool CallInliner::IsInlineableCallOp(HloInstruction* instruction) const { !instruction->has_backend_config() && !instruction->parent()->IsAsyncComputation() && InlineUnderShardy(instruction) && - InlineComposites(instruction, composites_to_preserve_); + InlineComposites(instruction, composites_to_preserve_) && + InlineStreamAnnotation(instruction); } absl::StatusOr CallInliner::Run( diff --git a/xla/service/call_inliner_test.cc b/xla/service/call_inliner_test.cc index b41606d1a93e7..15c704b768394 100644 --- a/xla/service/call_inliner_test.cc +++ b/xla/service/call_inliner_test.cc @@ -30,6 +30,7 @@ limitations under the License. 
#include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tests/filecheck.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" @@ -489,5 +490,61 @@ TEST_F(CallInlinerTest, UseShardManualComputationBodyInlined) { EXPECT_TRUE(changed); } +// With xla_gpu_experimental_stream_annotation enabled, a call carrying the +// _xla_stream_annotation frontend attribute must be left in place while an +// unannotated call in the same computation is still inlined. +TEST_F(CallInlinerTest, DontInlineStreamAnnotationCall) { + const absl::string_view hlo_string = R"( + HloModule composite + + %add (lhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] constant(2) + ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) + } + + %sub (lhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] constant(1) + ROOT %sub = f32[] subtract(f32[] %lhs, f32[] %rhs) + } + + ENTRY %main () -> f32[] { + %lhs = f32[] constant(42) + %call1 = f32[] call(f32[] %lhs), to_apply=%sub, frontend_attributes={_xla_stream_annotation="1"} + ROOT %call2 = f32[] call(f32[] %call1), to_apply=%add + })"; + + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_experimental_stream_annotation(true); + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); + module->mutable_config().set_debug_options(debug_options); + CallInliner call_inliner(/*single_call_site=*/true); + + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); + ASSERT_TRUE(mutated); + absl::StatusOr<bool> filecheck_result = RunFileCheck(module->ToString({}), R"( + //CHECK: %lhs.2 = f32[] constant(42) + //CHECK: %call1 = f32[] call(f32[] %lhs.2), to_apply=%sub, frontend_attributes={_xla_stream_annotation="1"} + //CHECK: %rhs.2 = f32[] constant(2) + //CHECK: ROOT %add.1 = f32[] add(f32[] %call1, f32[] %rhs.2) + )"); + TF_ASSERT_OK(filecheck_result.status()); + EXPECT_TRUE(*filecheck_result); + + ASSERT_EQ(module->entry_computation()->instruction_count(), 4); + auto inst = module->entry_computation()->instructions().begin(); + EXPECT_THAT(*inst, op::Constant()); + // Check that the annotated call 
isn't inlined + ++inst; + EXPECT_THAT(*inst, op::Call()); + + // Check that the non-annotated call is still inlined + ++inst; + EXPECT_THAT(*inst, op::Constant()); + ++inst; + EXPECT_THAT(*inst, op::Add()); +} + } // namespace } // namespace xla diff --git a/xla/side_effect_util.cc b/xla/side_effect_util.cc index 18e0144d863b5..f874bd4a5c6f3 100644 --- a/xla/side_effect_util.cc +++ b/xla/side_effect_util.cc @@ -59,6 +59,8 @@ const char kXlaBufferPlacementAttr[] = "_xla_buffer_placement"; const char kXlaBufferPlacementParam[] = "arg"; +const char kXlaStreamAnnotationAttr[] = "_xla_stream_annotation"; + const char kXlaCollectiveMatmulAttr[] = "_xla_collective_matmul"; const char kXlaCollectiveMatmulLhsAg[] = "lhs_ag"; diff --git a/xla/side_effect_util.h b/xla/side_effect_util.h index f16949fff635b..13a74a46d5a00 100644 --- a/xla/side_effect_util.h +++ b/xla/side_effect_util.h @@ -67,6 +67,9 @@ extern const char kXlaTableId[]; extern const char kXlaBufferPlacementAttr[]; extern const char kXlaBufferPlacementParam[]; +// XLA frontend attribute for stream annotation. +extern const char kXlaStreamAnnotationAttr[]; + // XLA frontend attribute for collective matmul control. extern const char kXlaCollectiveMatmulAttr[]; diff --git a/xla/xla.proto b/xla/xla.proto index 9c75901e3ab61..e15c240f32259 100644 --- a/xla/xla.proto +++ b/xla/xla.proto @@ -817,6 +817,10 @@ message DebugOptions { // Whether to use multiple compute streams to run windowed einsum. bool xla_gpu_multi_streamed_windowed_einsum = 280; + // Experimental: if true, call instructions that carry the + // "_xla_stream_annotation" frontend attribute are not inlined by CallInliner. + bool xla_gpu_experimental_stream_annotation = 342; + // If enabled, uses bf16_6way gemm to compute F32 gemm. bool xla_gpu_enable_bf16_6way_gemm = 271; @@ -1048,7 +1052,7 @@ message DebugOptions { } PGLEStrictnessLevel xla_gpu_pgle_accuracy_checker = 341; - // Next id: 342 + // Next id: 343 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend.