diff --git a/xla/service/BUILD b/xla/service/BUILD index 06be09542b2a0..07e0cf276331b 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -1992,7 +1992,6 @@ cc_library( deps = [ ":buffer_assignment_proto_cc", ":buffer_value", - ":buffer_value_containers", ":call_graph", ":hlo_alias_analysis", ":hlo_buffer", @@ -2003,7 +2002,6 @@ cc_library( ":logical_buffer", "//xla:shape_util", "//xla:status_macros", - "//xla:types", "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_live_range", @@ -2013,6 +2011,7 @@ cc_library( "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", @@ -2023,6 +2022,7 @@ cc_library( "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:numbers", + "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/service/buffer_assignment.cc b/xla/service/buffer_assignment.cc index 6cb157fb5a661..75b2dc8263f69 100644 --- a/xla/service/buffer_assignment.cc +++ b/xla/service/buffer_assignment.cc @@ -18,42 +18,54 @@ limitations under the License. #include "xla/service/buffer_assignment.h" #include +#include #include #include #include #include #include #include +#include +#include #include #include #include "absl/algorithm/container.h" -#include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_op_metadata.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/map_util.h" #include "xla/service/buffer_value.h" -#include "xla/service/buffer_value_containers.h" +#include "xla/service/call_graph.h" #include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_buffer.h" +#include "xla/service/hlo_dataflow_analysis.h" +#include "xla/service/hlo_ordering.h" #include "xla/service/hlo_value.h" +#include "xla/service/logical_buffer.h" +#include "xla/service/memory_space_assignment/memory_space_assignment.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/types.h" #include "xla/util.h" #include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" #include "tsl/platform/numbers.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -91,8 +103,8 @@ BuildIdToLogicalBufferMap( << "Expected logical buffer to have location information in the proto."; TF_RET_CHECK(id_to_hlo_instruction.contains( logical_buffer_proto.defined_at().instruction_id())) - << "Expected hlo instruction " - << "with the id '" << logical_buffer_proto.defined_at().instruction_id() + << "Expected hlo instruction " << "with the id '" + << logical_buffer_proto.defined_at().instruction_id() << "' in the proto to also exist in the " "HLO module."; // Assumption: An hlo module loaded from an hlo proto @@ -199,7 +211,7 @@ absl::Status GatherComputationsByAllocationType( break; default: return Internal("Unexpected calling opcode: %s", - HloOpcodeString(instruction->opcode())); + HloOpcodeString(instruction->opcode())); } } } @@ -330,9 +342,14 @@ static const HloInstruction* GetOutputInstruction( return nullptr; } -std::string BufferAllocation::ToShortString() const { +std::string BufferAllocation::ToShortString(bool human_readable_size) const { std::string output; - StrAppendFormat(&output, "allocation %d: size %d", index_, size()); + if (human_readable_size) { + StrAppendFormat(&output, "allocation %d: size %s", index_, + HumanReadableNumBytes(size())); + } else { + StrAppendFormat(&output, "allocation %d: size %d", index_, size()); + } if (color() != 0) { StrAppend(&output, ", color ", color()); } @@ -828,6 +845,151 @@ std::string BufferAssignment::ToString() const { return output; } +std::string BufferAssignment::MemoryUsageReport(float percentile, + int64_t more_than_k) const { + std::string output; + int64_t total_size = 0; + for (auto& allocation : allocations_) { + total_size += allocation.size(); + } + absl::StrAppend(&output, "Total bytes used: ", total_size, " (", + HumanReadableNumBytes(total_size), ")\n"); + + absl::StrAppend(&output, "\nAllocations sorted by size:\n\n"); + auto allocations = allocations_; + std::sort(allocations.begin(), allocations.end(), + [](const BufferAllocation& a, const BufferAllocation& b) { + if (a.size() > b.size()) return true; + if (a.size() < b.size()) return false; + return a.index() < b.index(); + }); + + int64_t cumulative_size = 0; + absl::StrAppend( + &output, "cumulative_size; total_size - cumulative_size; allocation\n"); + absl::StrAppend(&output, + "------------------------------------------------------------" + "------------------\n"); + int64_t index = 0; + for (auto& allocation : allocations) { + cumulative_size += allocation.size(); + absl::StrAppend( + &output, + absl::StrFormat("%10s(%3.0f%%); %10s; %s", + HumanReadableNumBytes(cumulative_size), + 100. * cumulative_size / total_size, + HumanReadableNumBytes(total_size - cumulative_size), + allocation.ToShortString(true))); + + // Skip the rest of the allocations if they are less than percentile of the + // total size and not more than k. + if (++index > more_than_k && + total_size - cumulative_size < total_size * percentile) { + absl::StrAppend( + &output, + absl::StrFormat( + "The rest %d allocations are less than %d%% of the total " + "size and not shown.\n", + allocations.size() - index, static_cast(percentile * 100))); + break; + } + } + + absl::StrAppend(&output, + "\n\nAllocations sorted by size with their values:\n"); + for (auto& allocation : allocations) { + if (allocation.assigned_buffers().size() == 1) { + absl::StrAppend(&output, allocation.ToShortString(true)); + } else { + StrAppendFormat( + &output, "%s\n%s\n", allocation.ToShortString(true), + allocation.MemoryUsageReport("\t", percentile, more_than_k)); + } + } + return output; +} + +std::string BufferAllocation::MemoryUsageReport(const std::string& prefix, + float percentile, + int64_t more_than_k) const { + std::string output; + + struct OffsetInfo { + std::vector values; + OffsetSize offset_size; + }; + + // Group the values by their offset in the allocation. + absl::flat_hash_map offset_to_buffers; + for (const auto& element : assigned_buffers_) { + const HloValue* value = element.first; + OffsetInfo& offset_info = offset_to_buffers[element.second.offset]; + offset_info.values.push_back(value); + offset_info.offset_size.offset = element.second.offset; + offset_info.offset_size.size = + std::max(offset_info.offset_size.size, element.second.size); + } + + // Sort the offset infos by the max size of the values in the group. + std::vector sorted_offset_infos; + int64_t total_size = 0; + for (auto& element : offset_to_buffers) { + total_size += element.second.offset_size.size; + sorted_offset_infos.push_back(std::move(element.second)); + } + absl::c_sort(sorted_offset_infos, + [](const OffsetInfo& a, const OffsetInfo& b) { + return a.offset_size.size > b.offset_size.size; + }); + + StrAppend(&output, prefix, + "cumulative_size; size; offset; used_by_n_values; " + "shapes_list\n"); + StrAppend(&output, prefix, + "------------------------------------------------------------\n"); + int64_t cumulative_size = 0; + int64_t index = 0; + for (const auto& offset_info : sorted_offset_infos) { + cumulative_size += offset_info.offset_size.size; + StrAppendFormat(&output, "%s%9s(%3.0f%%); %10s; %12d; %16d; ", prefix, + xla::HumanReadableNumBytes(cumulative_size), + 100. * cumulative_size / total_size, + xla::HumanReadableNumBytes(offset_info.offset_size.size), + offset_info.offset_size.offset, offset_info.values.size()); + + // Count the number of values with the same shape and append them at the end + // of the line. + absl::flat_hash_map shapes; + for (auto& value : offset_info.values) shapes[value->shape().ToString()]++; + + StrAppend( + &output, + absl::StrJoin(shapes, ", ", [](std::string* out, const auto& pair) { + if (pair.second == 1) { + return absl::StrAppend(out, pair.first); + } + return absl::StrAppend(out, pair.second, "×", pair.first); + })); + + StrAppend(&output, "\n"); + + // Skip the rest of the values if they are less than percentile of the + // total size and not more than k. + if (++index > more_than_k && + total_size - cumulative_size < total_size * percentile) { + StrAppendFormat( + &output, + "%sThe rest %d values are less than %d%% of the total size and not " + "shown.\n", + prefix, sorted_offset_infos.size() - index, + static_cast(percentile * 100)); + break; + } + } + + return output; +} + // Returns the largest k buffers present at the point of peak memory usage // across allocations as a vector of pairs with their corresponding sizes. std::vector> TopKPeakBuffers( diff --git a/xla/service/buffer_assignment.h b/xla/service/buffer_assignment.h index 337d9faa9f64a..fbffc282eb7f0 100644 --- a/xla/service/buffer_assignment.h +++ b/xla/service/buffer_assignment.h @@ -228,7 +228,18 @@ class BufferAllocation { Slice GetSlice(const HloValue& buffer) const; std::string ToString() const; - std::string ToShortString() const; + std::string ToShortString(bool human_readable_size = false) const; + std::string ValuesToString() const; + + // The function returns memory usage report for the values belonging to the + // buffer allocation. The values are grouped by their offset in the + // allocation. The groups are sorted by the max size(Z-A) of the values in the + // group. Percentile and more_than_k are used to control the number of groups + // being reported. + std::string MemoryUsageReport(const std::string& prefix, + float percentile = 0.05, + int64_t more_than_k = 50) const; + BufferAllocationProto ToProto() const; // Whether the buffer is a parameter to or live out of the entry computation. @@ -486,10 +497,18 @@ class BufferAssignment { // Returns the HloLiveRange object used to construct this assignment. const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; } + // Is in use by many compilers to dump the buffer-assignment info. std::string ToString() const; + + // Returns a memory usage report with the list of buffer allocations ordered + // by the size(Z-A) and the values assigned to each buffer allocation. + std::string MemoryUsageReport(float percentile = 0.05, + int64_t more_than_k = 50) const; // Verbose string tailored to debugging OOMs, includes the Hlo op metadata for // every buffer associated with each allocation. std::string ToVerboseString(size_t max_buffers_to_show) const; + + // Is in use by tpu compiler to dump the buffer info. std::string BufferInfoString() const; // Convert BufferAssignment to or from a proto. diff --git a/xla/service/dump.cc b/xla/service/dump.cc index fcfbc22159d4a..09153157cb3a8 100644 --- a/xla/service/dump.cc +++ b/xla/service/dump.cc @@ -460,13 +460,17 @@ static std::vector DumpHloModuleImpl( file_paths.push_back(DumpToFileInDirOrStdoutImpl( StrCat(filename, ".txt"), module.ToString(print_options), opts)); if (buffer_assn) { - DataProducer data_producer; - data_producer.Append([&] { return buffer_assn->ToString(); }); - data_producer.Append([&] { return "\n\n"; }); - data_producer.Append( + DataProducer buffer_assignment; + buffer_assignment.Append([&] { return buffer_assn->ToString(); }); + buffer_assignment.Append([&] { return "\n\n"; }); + buffer_assignment.Append( [&] { return buffer_assn->hlo_live_range().ToString(); }); file_paths.push_back(DumpToFileInDirOrStdoutImpl( - StrCat(filename, "-buffer-assignment.txt"), data_producer, opts)); + StrCat(filename, "-buffer-assignment.txt"), buffer_assignment, opts)); + DataProducer summary_report; + summary_report.Append([&] { return buffer_assn->MemoryUsageReport(); }); + file_paths.push_back(DumpToFileInDirOrStdoutImpl( + StrCat(filename, "-memory-usage-report.txt"), summary_report, opts)); } }