Skip to content

Commit

Permalink
Add default compiler options for v6e
Browse files Browse the repository at this point in the history
  • Loading branch information
samos123 committed Dec 11, 2024
1 parent 73625c9 commit 655c539
Showing 1 changed file with 69 additions and 4 deletions.
73 changes: 69 additions & 4 deletions axlearn/common/compiler_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
# This module must not depend on any jax/axlearn modules so that
# importing this module does not result in initializing jax.
import re
from typing import Any, Union
from typing import Any, Dict, Union


def default_xla_options(
*, instance_type: str, num_slices: int, backend: str
) -> dict[str, Union[str, bool]]:
) -> dict[str, Union[str, bool, int]]:
"""Return the default flags for the given instance type and backend.
These options can be passed to `jitted_fn.lower(...).compile(compiler_options=...)`
Expand All @@ -31,7 +31,7 @@ def default_xla_options(
if backend != "tpu":
raise NotImplementedError(backend)
version = infer_tpu_version(infer_tpu_type(instance_type))
options = dict(
options: Dict[str, int | str | bool] = dict(
xla_tpu_spmd_rng_bit_generator_unsafe=True, # SPMD partition-aware RngBitGenerator.
xla_tpu_enable_latency_hiding_scheduler="true", # Try to schedule ops efficiently.
xla_tpu_perform_spmd_cse_prevention="false",
Expand All @@ -44,6 +44,62 @@ def default_xla_options(
xla_enable_async_all_gather="true", # Allow async all-gather.
xla_enable_async_collective_permute="true", # Allow async collective permute.
)
if version == "v6e":
options.update(
# Improved performance for v6e.
xla_tpu_scoped_vmem_limit_kib=98304,
xla_tpu_enable_async_collective_fusion="true",
xla_tpu_enable_async_collective_fusion_fuse_all_gather="true",
xla_tpu_enable_async_collective_fusion_multiple_steps="true",
xla_tpu_overlap_compute_collective_tc="true",
xla_enable_async_all_gather="true",
# Host offloading flags
xla_tpu_enable_all_experimental_scheduler_features="true",
# Flag to enable memory tracking scheduling. The default AUTO only enables
# it in some situations. Not needed if
# xla_tpu_enable_all_experimental_scheduler_features is set to true already.
xla_tpu_enable_scheduler_memory_pressure_tracking="true",
# Flag controlling the maximum number of overlapping host offloadings.
xla_tpu_host_transfer_overlap_limit=24,
# Flag to enable the aggressive removal of opt-barriers.
xla_tpu_aggressive_opt_barrier_removal="true",
# Flag to enable more aggressive scheduling for async ops, such as pushing
# the async start to the beginning of the loop body.
xla_lhs_prioritize_async_depth_over_stall="true",
# Flag to enable pipelining of cross-DCN all-gathers.
xla_tpu_enable_ag_backward_pipelining="true",
xla_should_allow_loop_variant_parameter_in_chain="true",
xla_should_add_loop_invariant_op_in_chain="true",
# Flag controlling the maximum number of overlapping cross-DCN send/recv.
xla_max_concurrent_host_send_recv=100,
# Flag controlling the HBM memory limit as a percentage of the total HBM size.
# Default value is 95. Can tune up or down to give more or less memory for the
# scheduler. The scheduler favors more on less memory usage when it's under
# memory pressure, instead of hiding latency by overlapping more computations
# and communications.
xla_tpu_scheduler_percent_shared_memory_limit=90,
# Flag controlling the number of times the scheduler is run if the scheduled
# peak memory usage exceeds the initial memory limit, by setting memory limit
# to 90% of the previous memory limit each time. Default value is 1. Sometimes
# when the scheduler thinks it goes out memory, it may not actually happen due
# to other factors controlled by other compiler passes, or the initial memory
# limit is already set too low. Cutting the memory limit to 90% of previous one
# though, may make the scheduler weighting too much on the memory usage instead
# of latency side.
xla_latency_hiding_scheduler_rerun=2,
xla_tpu_use_enhanced_launch_barrier="true",
# Sparsecore offloading for all reduce.
# Uncomment below flags to enable it.
# xla_sc_disable_megacore_partitioning="true",
# xla_tpu_use_tc_device_shape_on_sc="true",
# tpu_use_continuations="true",
# xla_jf_crs_combiner_threshold_count=10,
# xla_sc_enable_instruction_fusion="false",
# xla_sc_disjoint_spmem="false",
# xla_tpu_enable_sparse_core_collective_offload_all_reduce="true",
)
# This flag can be removed after upgrading to Jax 0.4.38.
options["2a886c8_chip_config_name"] = "megachip_tccontrol"
if num_slices > 1:
# Support multiple TPU slices connected over a data center network.
options.update(
Expand All @@ -55,11 +111,20 @@ def default_xla_options(
xla_tpu_data_parallel_opt_different_sized_ops="true",
# Group non-blocking DCN collectives into as few stages as possible.
xla_tpu_enable_sunk_dcn_allreduce_done_with_host_reduction="true",
# Change to 16GB. The default is 4GB which is too small for larger models. This
# cause the step time to be double. You should increase this
# further if you see "Allocator failed to allocate". A feature
# to dynamically allocate may come later: b/380514965
megascale_grpc_premap_memory_bytes=17179869184,
)

# Validate options. Will never fail if this function is implemented correctly.
for k, v in options.items():
assert v in [True, False, "true", "false"], (k, v)
try:
int(v)
continue
except ValueError:
assert v in [True, False, "true", "false", "megachip_tccontrol"], (k, v)

return options

Expand Down

0 comments on commit 655c539

Please sign in to comment.