From 564455c9cf78edf8e3dd76343b74b2882e86a2d6 Mon Sep 17 00:00:00 2001
From: Sam Stoelinga
Date: Wed, 11 Dec 2024 12:22:58 -0800
Subject: [PATCH] Add default compiler options for v6e

---
 axlearn/common/compiler_options.py | 71 ++++++++++++++++++++++++++++--
 1 file changed, 67 insertions(+), 4 deletions(-)

diff --git a/axlearn/common/compiler_options.py b/axlearn/common/compiler_options.py
index 77ab656f..fb4c1034 100644
--- a/axlearn/common/compiler_options.py
+++ b/axlearn/common/compiler_options.py
@@ -4,12 +4,12 @@
 # This module must not depend on any jax/axlearn modules so that
 # importing this module does not result in initializing jax.
 import re
-from typing import Any, Union
+from typing import Any, Dict, Union
 
 
 def default_xla_options(
     *, instance_type: str, num_slices: int, backend: str
-) -> dict[str, Union[str, bool]]:
+) -> dict[str, Union[str, bool, int]]:
     """Return the default flags for the given instance type and backend.
 
     These options can be passed to `jitted_fn.lower(...).compile(compiler_options=...)`
@@ -31,7 +31,7 @@ def default_xla_options(
     if backend != "tpu":
         raise NotImplementedError(backend)
     version = infer_tpu_version(infer_tpu_type(instance_type))
-    options = dict(
+    options: Dict[str, Union[str, bool, int]] = dict(
         xla_tpu_spmd_rng_bit_generator_unsafe=True,  # SPMD partition-aware RngBitGenerator.
         xla_tpu_enable_latency_hiding_scheduler="true",  # Try to schedule ops efficiently.
         xla_tpu_perform_spmd_cse_prevention="false",
@@ -44,6 +44,60 @@ def default_xla_options(
         xla_enable_async_all_gather="true",  # Allow async all-gather.
         xla_enable_async_collective_permute="true",  # Allow async collective permute.
     )
+    if version == "v6e":
+        options.update(
+            # Improved performance for v6e.
+            xla_tpu_scoped_vmem_limit_kib=98304,
+            xla_tpu_enable_async_collective_fusion="true",
+            xla_tpu_enable_async_collective_fusion_fuse_all_gather="true",
+            xla_tpu_enable_async_collective_fusion_multiple_steps="true",
+            xla_tpu_overlap_compute_collective_tc="true",
+            xla_enable_async_all_gather="true",
+            # Host offloading flags.
+            xla_tpu_enable_all_experimental_scheduler_features="true",
+            # Flag to enable memory tracking scheduling. The default AUTO only enables it
+            # in some situations. Not needed if
+            # xla_tpu_enable_all_experimental_scheduler_features is already set to true.
+            xla_tpu_enable_scheduler_memory_pressure_tracking="true",
+            # Flag controlling the maximum number of overlapping host offloadings.
+            xla_tpu_host_transfer_overlap_limit=24,
+            # Flag to enable the aggressive removal of opt-barriers.
+            xla_tpu_aggressive_opt_barrier_removal="true",
+            # Flag to enable more aggressive scheduling for async ops, such as pushing
+            # the async start to the beginning of the loop body.
+            xla_lhs_prioritize_async_depth_over_stall="true",
+            # Flag to enable pipelining of cross-DCN all-gathers.
+            xla_tpu_enable_ag_backward_pipelining="true",
+            xla_should_allow_loop_variant_parameter_in_chain="true",
+            xla_should_add_loop_invariant_op_in_chain="true",
+            # Flag controlling the maximum number of overlapping cross-DCN send/recv.
+            xla_max_concurrent_host_send_recv=100,
+            # Flag controlling the HBM memory limit as a percentage of the total HBM size.
+            # The default value is 95. Tune it up or down to give the scheduler more or
+            # less memory. When under memory pressure, the scheduler favors lower memory
+            # usage over hiding latency by overlapping more computations and
+            # communications.
+            xla_tpu_scheduler_percent_shared_memory_limit=90,
+            # Flag controlling the number of times the scheduler is rerun if the scheduled
+            # peak memory usage exceeds the initial memory limit; each rerun lowers the
+            # memory limit to 90% of the previous one. The default value is 1. When the
+            # scheduler thinks it has run out of memory, that may not actually be the
+            # case, either because other compiler passes compensate or because the
+            # initial memory limit was set too low. Repeatedly cutting the memory limit
+            # to 90% of the previous one, however, may make the scheduler weight memory
+            # usage too heavily at the expense of latency.
+            xla_latency_hiding_scheduler_rerun=2,
+            xla_tpu_use_enhanced_launch_barrier="true",
+            # SparseCore offloading for all-reduce.
+            xla_sc_disable_megacore_partitioning="true",
+            xla_tpu_use_tc_device_shape_on_sc="true",
+            tpu_use_continuations="true",
+            xla_jf_crs_combiner_threshold_count=10,
+            xla_sc_enable_instruction_fusion="false",
+            xla_sc_disjoint_spmem="false",
+            xla_tpu_enable_sparse_core_collective_offload_all_reduce="true",
+        )
+        options["2a886c8_chip_config_name"] = "megachip_tccontrol"
     if num_slices > 1:
         # Support multiple TPU slices connected over a data center network.
         options.update(
@@ -55,11 +109,20 @@ def default_xla_options(
             xla_tpu_data_parallel_opt_different_sized_ops="true",
             # Group non-blocking DCN collectives into as few stages as possible.
             xla_tpu_enable_sunk_dcn_allreduce_done_with_host_reduction="true",
+            # Change to 16GB. The default is 4GB, which is too small for larger models
+            # and causes the step time to double. Increase this further if you see
+            # "Allocator failed to allocate". A feature to allocate this memory
+            # dynamically may come later: b/380514965
+            megascale_grpc_premap_memory_bytes=17179869184,
         )
 
     # Validate options. Will never fail if this function is implemented correctly.
     for k, v in options.items():
-        assert v in [True, False, "true", "false"], (k, v)
+        try:
+            int(v)
+            continue
+        except ValueError:
+            assert v in [True, False, "true", "false", "megachip_tccontrol"], (k, v)
 
     return options
 
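Usage note (not part of the patch): a minimal sketch of how the returned options can be passed to `jitted_fn.lower(...).compile(compiler_options=...)`, as the function's docstring describes. The instance type string "tpu-v6e-256" and the toy jitted function are illustrative assumptions, and the snippet assumes it runs on a TPU backend, since other backends raise NotImplementedError.

    # Sketch only: assumes a TPU (v6e) runtime; "tpu-v6e-256" is an example instance type.
    import jax
    import jax.numpy as jnp

    from axlearn.common.compiler_options import default_xla_options

    # All arguments are keyword-only.
    options = default_xla_options(instance_type="tpu-v6e-256", num_slices=1, backend="tpu")

    # Pass the flags at compile time via JAX's ahead-of-time lowering API.
    jitted_fn = jax.jit(lambda x: x * 2)
    x = jnp.ones((8, 8), dtype=jnp.float32)
    compiled = jitted_fn.lower(x).compile(compiler_options=options)
    print(compiled(x))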