diff --git a/xla/service/latency_hiding_scheduler.cc b/xla/service/latency_hiding_scheduler.cc
index f0e5af7ac3c9a..c4e812e7e5aba 100644
--- a/xla/service/latency_hiding_scheduler.cc
+++ b/xla/service/latency_hiding_scheduler.cc
@@ -73,12 +73,72 @@ bool IsNopInstruction(const HloInstruction& hlo) {
          (op == HloOpcode::kTuple && hlo.user_count() == 1 &&
           hlo.users().front()->opcode() == HloOpcode::kWhile);
 }
+
+bool InstructionDefinesValue(const HloInstruction* instruction,
+                             const HloValue* value) {
+  if (value->defining_instruction() == instruction) {
+    return true;
+  }
+  if (value->shape().has_layout() &&
+      value->shape().layout().memory_space() != kDefaultMemorySpace) {
+    return false;
+  }
+  // Also check if the instruction is a call to a computation that defines the
+  // value. This is needed in cases, e.g., where we wrap a value-defining
+  // instruction in an async call for offloading, and the async call itself
+  // will effectively define the value in the current scope that the scheduler
+  // is running in.
+  if (instruction->opcode() == HloOpcode::kAsyncStart ||
+      instruction->opcode() == HloOpcode::kAsyncDone) {
+    if (instruction->async_wrapped_opcode() == HloOpcode::kCall) {
+      return instruction->async_wrapped_instruction()
+                 ->called_computations()[0]
+                 ->root_instruction() == value->defining_instruction();
+    }
+    return instruction->async_wrapped_instruction() ==
+           value->defining_instruction();
+  }
+  return false;
+}
+
+bool InstructionFirstDefinesBuffer(
+    const HloInstruction* instruction,
+    const BufferInfoTracker::ValueInfo& buffer_value_info) {
+  if (buffer_value_info.first_definition == instruction) {
+    return true;
+  }
+  if (buffer_value_info.value->values()[0]->shape().has_layout() &&
+      buffer_value_info.value->values()[0]->shape().layout().memory_space() !=
+          kDefaultMemorySpace) {
+    return false;
+  }
+  // Similar to the logic above, also check if the instruction is a call to a
+  // computation that defines the value.
+  if (instruction->opcode() == HloOpcode::kAsyncStart ||
+      instruction->opcode() == HloOpcode::kAsyncDone) {
+    if (instruction->async_wrapped_opcode() == HloOpcode::kCall) {
+      return instruction->async_wrapped_instruction()
+                 ->called_computations()[0]
+                 ->root_instruction() == buffer_value_info.first_definition;
+    }
+    return instruction->async_wrapped_instruction() ==
+           buffer_value_info.first_definition;
+  }
+  return false;
+}
+
 }  // namespace
 
 CanonicalAsyncOp DefaultGetCanonicalAsyncOp(const HloInstruction& hlo) {
   switch (hlo.opcode()) {
     case HloOpcode::kAsyncStart:
     case HloOpcode::kAsyncDone:
+      if (hlo.async_wrapped_opcode() == HloOpcode::kCall) {
+        return {hlo.opcode(), hlo.async_wrapped_instruction()
+                                  ->called_computations()[0]
+                                  ->root_instruction()
+                                  ->opcode()};
+      }
       return {hlo.opcode(), hlo.async_wrapped_opcode()};
     case HloOpcode::kAllReduceStart:
       return {HloOpcode::kAsyncStart, HloOpcode::kAllReduce};
@@ -596,7 +656,7 @@ void MemoryPressureTracker::Initialize(
           output_values.push_back(std::make_pair(
               buffer_tracker_.GetBufferInfo(buffer->id()), index));
           if (absl::c_any_of(buffer->values(), [&](const HloValue* value) {
-                return value->defining_instruction() == instruction;
+                return InstructionDefinesValue(instruction, value);
               })) {
             defined_values.push_back(
                 buffer_tracker_.GetBufferInfo(buffer->id()));
@@ -663,7 +723,7 @@ void MemoryPressureTracker::UpdateBuffers(const HloInstruction* instruction) {
       continue;
     }
     if (live_buffers_[b.value->id()] != 0) {
-      if (b.first_definition == instruction) {
+      if (InstructionFirstDefinesBuffer(instruction, b)) {
        live_memory_usage_ -= b.buffer_size;
        live_buffers_set_.erase(b.value->id());
      }
@@ -721,7 +781,7 @@ std::pair<int64_t, int64_t> MemoryPressureTracker::MemoryPressureDifference(
       continue;
     }
     if (live_buffers_[b.value->id()]) {
-      if (b.first_definition == instruction) {
+      if (InstructionFirstDefinesBuffer(instruction, b)) {
         increase -= b.buffer_size;
       }
     }
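
Note: all three changed call sites (DefaultGetCanonicalAsyncOp, InstructionDefinesValue, InstructionFirstDefinesBuffer) share the same "unwrap through kCall" dispatch: when an async-start/-done wraps a kCall, the effective inner instruction is the root of the called computation, not the call itself. The following minimal sketch illustrates just that dispatch; the Opcode, Computation, and Instruction types below are hypothetical mocks for illustration, not the real XLA classes.

#include <iostream>

// Hypothetical mock types, not the real XLA classes.
enum class Opcode { kAsyncStart, kAsyncDone, kCall, kFusion };

struct Computation {
  Opcode root_opcode;  // opcode of the computation's root instruction
};

struct Instruction {
  Opcode opcode;              // kAsyncStart or kAsyncDone here
  Opcode wrapped_opcode;      // opcode wrapped by the async op
  const Computation* called;  // computation called when wrapping a kCall
};

// Mirrors the new branch: if the async op wraps a kCall, report the root
// opcode of the called computation instead of kCall itself.
Opcode ResolveInnerOpcode(const Instruction& async_op) {
  if (async_op.wrapped_opcode == Opcode::kCall) {
    return async_op.called->root_opcode;
  }
  return async_op.wrapped_opcode;
}

int main() {
  Computation offload_body{Opcode::kFusion};
  Instruction start{Opcode::kAsyncStart, Opcode::kCall, &offload_body};
  // The async call canonicalizes to the called computation's root (kFusion),
  // so this prints 1.
  std::cout << (ResolveInnerOpcode(start) == Opcode::kFusion) << '\n';
}

This is why the memory pressure tracker needs the new helpers: once an offloaded value-defining instruction is wrapped in an async call, comparing value->defining_instruction() against the scheduled instruction alone would miss the async wrapper that effectively defines the value in the scheduler's scope.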