Skip to content

Commit

Permalink
PR #18448: Optionally don't inline stream annotated kCalls
Browse files Browse the repository at this point in the history
Imported from GitHub PR #18448

First part of splitting up #17982
Copybara import of the project:

--
8ecc06c by chaser <[email protected]>:

Don't inline stream annotated kCalls

Merging this change closes #18448

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18448 from chaserileyroberts:chase/stream_call_noinline 8ecc06c
PiperOrigin-RevId: 691070675
  • Loading branch information
chaserileyroberts authored and Google-ML-Automation committed Nov 7, 2024
1 parent 2ad465f commit 6b3cf9c
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 2 deletions.
1 change: 1 addition & 0 deletions xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_gpu_nccl_p2p_max_nchannels(0);
opts.set_xla_gpu_multi_streamed_windowed_einsum(false);

opts.set_xla_gpu_experimental_stream_annotation(false);
// Minimum combined size of matrices in matrix multiplication to
// be rewritten to cuBLAS or Triton kernel call.
// This threshold is a conservative estimate and has been measured
Expand Down
2 changes: 2 additions & 0 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,7 @@ cc_library(
deps = [
":call_graph",
":hlo_domain_isolator",
"//xla:side_effect_util",
"//xla:status_macros",
"//xla:util",
"//xla:xla_data_proto_cc",
Expand Down Expand Up @@ -935,6 +936,7 @@ xla_cc_test(
"//xla/hlo/ir:hlo",
"//xla/hlo/parser:hlo_parser",
"//xla/hlo/utils:hlo_matchers",
"//xla/tests:filecheck",
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main",
"//xla/tsl/lib/core:status_test_util",
Expand Down
17 changes: 16 additions & 1 deletion xla/service/call_inliner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ limitations under the License.
#include "xla/service/call_graph.h"
#include "xla/service/hlo_domain_isolator.h"
#include "xla/service/spmd/shardy/constants.h"
#include "xla/side_effect_util.h"
#include "xla/status_macros.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
Expand Down Expand Up @@ -160,6 +161,19 @@ bool InlineComposites(
instruction->frontend_attributes().map().at("composite.name"));
}

bool InlineStreamAnnotation(HloInstruction* instruction) {
if (instruction->GetModule()
->config()
.debug_options()
.xla_gpu_experimental_stream_annotation()) {
if (instruction->frontend_attributes().map().contains(
kXlaStreamAnnotationAttr)) {
return false;
}
}
return true;
}

} // namespace

/* static */ absl::StatusOr<CallInliner::InlinedInstructionMap>
Expand Down Expand Up @@ -213,7 +227,8 @@ bool CallInliner::IsInlineableCallOp(HloInstruction* instruction) const {
!instruction->has_backend_config() &&
!instruction->parent()->IsAsyncComputation() &&
InlineUnderShardy(instruction) &&
InlineComposites(instruction, composites_to_preserve_);
InlineComposites(instruction, composites_to_preserve_) &&
InlineStreamAnnotation(instruction);
}

absl::StatusOr<bool> CallInliner::Run(
Expand Down
54 changes: 54 additions & 0 deletions xla/service/call_inliner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ limitations under the License.
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/test.h"
#include "xla/tests/filecheck.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/xla_data.pb.h"
Expand Down Expand Up @@ -494,5 +495,58 @@ TEST_F(CallInlinerTest, UseShardManualComputationBodySurroundedNotInlined) {
"my_model.___call__.fwd.xla.sdy.manual_computation_body_14.1234");
}

TEST_F(CallInlinerTest, DontInlineStreamAnnotationCall) {
const absl::string_view hlo_string = R"(
HloModule composite
%add (lhs: f32[]) -> f32[] {
%lhs = f32[] parameter(0)
%rhs = f32[] constant(2)
ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
}
%sub (lhs: f32[]) -> f32[] {
%lhs = f32[] parameter(0)
%rhs = f32[] constant(1)
ROOT %sub = f32[] subtract(f32[] %lhs, f32[] %rhs)
}
ENTRY %main () -> f32[] {
%lhs = f32[] constant(42)
%call1 = f32[] call(f32[] %lhs), to_apply=%sub, frontend_attributes={_xla_stream_annotation="1"}
ROOT %call2 = f32[] call(f32[] %call1), to_apply=%add
})";

auto debug_options = HloTestBase::GetDebugOptionsForTest();
debug_options.set_xla_gpu_experimental_stream_annotation(true);
auto module = ParseAndReturnVerifiedModule(hlo_string).value();
module->mutable_config().set_debug_options(debug_options);
CallInliner call_inliner(/*single_call_site=*/true);

TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
absl::StatusOr<bool> filecheck_result = RunFileCheck(module->ToString({}), R"(
//CHECK: %lhs.2 = f32[] constant(42)
//CHECK: %call1 = f32[] call(f32[] %lhs.2), to_apply=%sub, frontend_attributes={_xla_stream_annotation="1"}
//CHECK: %rhs.2 = f32[] constant(2)
//CHECK: ROOT %add.1 = f32[] add(f32[] %call1, f32[] %rhs.2)
)");
TF_ASSERT_OK(filecheck_result.status());
EXPECT_TRUE(*filecheck_result);

ASSERT_TRUE(mutated);
ASSERT_EQ(module->entry_computation()->instruction_count(), 4);
auto inst = module->entry_computation()->instructions().begin();
EXPECT_THAT(*inst, op::Constant());
// Check that the annotated call isn't inlined
++inst;
EXPECT_THAT(*inst, op::Call());

// Check that the non-annotated call is still inlined
++inst;
EXPECT_THAT(*inst, op::Constant());
++inst;
EXPECT_THAT(*inst, op::Add());
}

} // namespace
} // namespace xla
2 changes: 2 additions & 0 deletions xla/side_effect_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ const char kXlaBufferPlacementAttr[] = "_xla_buffer_placement";

const char kXlaBufferPlacementParam[] = "arg";

const char kXlaStreamAnnotationAttr[] = "_xla_stream_annotation";

const char kXlaCollectiveMatmulAttr[] = "_xla_collective_matmul";

const char kXlaCollectiveMatmulLhsAg[] = "lhs_ag";
Expand Down
3 changes: 3 additions & 0 deletions xla/side_effect_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ extern const char kXlaTableId[];
extern const char kXlaBufferPlacementAttr[];
extern const char kXlaBufferPlacementParam[];

// XLA frontend attribute for stream annotation.
extern const char kXlaStreamAnnotationAttr[];

// XLA frontend attribute for collective matmul control.
extern const char kXlaCollectiveMatmulAttr[];

Expand Down
4 changes: 3 additions & 1 deletion xla/xla.proto
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,8 @@ message DebugOptions {
// Whether to use multiple compute streams to run windowed einsum.
bool xla_gpu_multi_streamed_windowed_einsum = 280;

bool xla_gpu_experimental_stream_annotation = 342;

// If enabled, uses bf16_6way gemm to compute F32 gemm.
bool xla_gpu_enable_bf16_6way_gemm = 271;

Expand Down Expand Up @@ -1034,7 +1036,7 @@ message DebugOptions {
}
PGLEStrictnessLevel xla_gpu_pgle_accuracy_checker = 341;

// Next id: 342
// Next id: 343

// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.
Expand Down

0 comments on commit 6b3cf9c

Please sign in to comment.