Skip to content

Commit

Permalink
Add logic to track definitions across calls in the scheduler.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 678948517
  • Loading branch information
Google-ML-Automation committed Sep 26, 2024
1 parent 03b93df commit 88648d3
Showing 1 changed file with 63 additions and 3 deletions.
66 changes: 63 additions & 3 deletions xla/service/latency_hiding_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,72 @@ bool IsNopInstruction(const HloInstruction& hlo) {
(op == HloOpcode::kTuple && hlo.user_count() == 1 &&
hlo.users().front()->opcode() == HloOpcode::kWhile);
}

bool InstructionDefinesValue(const HloInstruction* instruction,
const HloValue* value) {
if (value->defining_instruction() == instruction) {
return true;
}
if (value->shape().has_layout() &&
value->shape().layout().memory_space() != kDefaultMemorySpace) {
return false;
}
// Also check if the instruction is a call to a computation that defines the
// value. This is needed in cases, e.g., where we wrap a value-defining
// instruction in a async call for offloading, and the async call itself will
// effectively define the value in the current scope that the scheduler is
// running in.
if (instruction->opcode() == HloOpcode::kAsyncStart ||
instruction->opcode() == HloOpcode::kAsyncDone) {
if (instruction->async_wrapped_opcode() == HloOpcode::kCall) {
return instruction->async_wrapped_instruction()
->called_computations()[0]
->root_instruction() == value->defining_instruction();
}
return instruction->async_wrapped_instruction() ==
value->defining_instruction();
}
return false;
}

bool InstructionFirstDefinesBuffer(
const HloInstruction* instruction,
const BufferInfoTracker::ValueInfo& buffer_value_info) {
if (buffer_value_info.first_definition == instruction) {
return true;
}
if (buffer_value_info.value->values()[0]->shape().has_layout() &&
buffer_value_info.value->values()[0]->shape().layout().memory_space() !=
kDefaultMemorySpace) {
return false;
}
// Similar to logic above, also check if the instruction is a call to a
// computation that defines the value.
if (instruction->opcode() == HloOpcode::kAsyncStart ||
instruction->opcode() == HloOpcode::kAsyncDone) {
if (instruction->async_wrapped_opcode() == HloOpcode::kCall) {
return instruction->async_wrapped_instruction()
->called_computations()[0]
->root_instruction() == buffer_value_info.first_definition;
}
return instruction->async_wrapped_instruction() ==
buffer_value_info.first_definition;
}
return false;
}

} // namespace

CanonicalAsyncOp DefaultGetCanonicalAsyncOp(const HloInstruction& hlo) {
switch (hlo.opcode()) {
case HloOpcode::kAsyncStart:
case HloOpcode::kAsyncDone:
if (hlo.async_wrapped_opcode() == HloOpcode::kCall) {
return {hlo.opcode(), hlo.async_wrapped_instruction()
->called_computations()[0]
->root_instruction()
->opcode()};
}
return {hlo.opcode(), hlo.async_wrapped_opcode()};
case HloOpcode::kAllReduceStart:
return {HloOpcode::kAsyncStart, HloOpcode::kAllReduce};
Expand Down Expand Up @@ -596,7 +656,7 @@ void MemoryPressureTracker::Initialize(
output_values.push_back(std::make_pair(
buffer_tracker_.GetBufferInfo(buffer->id()), index));
if (absl::c_any_of(buffer->values(), [&](const HloValue* value) {
return value->defining_instruction() == instruction;
return InstructionDefinesValue(instruction, value);
})) {
defined_values.push_back(
buffer_tracker_.GetBufferInfo(buffer->id()));
Expand Down Expand Up @@ -663,7 +723,7 @@ void MemoryPressureTracker::UpdateBuffers(const HloInstruction* instruction) {
continue;
}
if (live_buffers_[b.value->id()] != 0) {
if (b.first_definition == instruction) {
if (InstructionFirstDefinesBuffer(instruction, b)) {
live_memory_usage_ -= b.buffer_size;
live_buffers_set_.erase(b.value->id());
}
Expand Down Expand Up @@ -721,7 +781,7 @@ std::pair<int64_t, int64_t> MemoryPressureTracker::MemoryPressureDifference(
continue;
}
if (live_buffers_[b.value->id()]) {
if (b.first_definition == instruction) {
if (InstructionFirstDefinesBuffer(instruction, b)) {
increase -= b.buffer_size;
}
}
Expand Down

0 comments on commit 88648d3

Please sign in to comment.