Add default compiler options for v6e (#887)
* Add default compiler options for v6e

* address PR comments

* use union

* make pytype happy
samos123 authored Dec 12, 2024
1 parent 73625c9 commit 420fb66
Showing 1 changed file with 71 additions and 5 deletions.
76 changes: 71 additions & 5 deletions axlearn/common/compiler_options.py
@@ -4,12 +4,12 @@
# This module must not depend on any jax/axlearn modules so that
# importing this module does not result in initializing jax.
import re
from typing import Any, Union
from typing import Any, Dict, Union


def default_xla_options(
*, instance_type: str, num_slices: int, backend: str
) -> dict[str, Union[str, bool]]:
) -> dict[str, Union[str, bool, int]]:
"""Return the default flags for the given instance type and backend.
These options can be passed to `jitted_fn.lower(...).compile(compiler_options=...)`
@@ -31,7 +31,7 @@ def default_xla_options(
if backend != "tpu":
raise NotImplementedError(backend)
version = infer_tpu_version(infer_tpu_type(instance_type))
options = dict(
options: Dict[str, Union[int, str, bool]] = dict(
xla_tpu_spmd_rng_bit_generator_unsafe=True, # SPMD partition-aware RngBitGenerator.
xla_tpu_enable_latency_hiding_scheduler="true", # Try to schedule ops efficiently.
xla_tpu_perform_spmd_cse_prevention="false",
@@ -44,6 +44,68 @@ def default_xla_options(
xla_enable_async_all_gather="true", # Allow async all-gather.
xla_enable_async_collective_permute="true", # Allow async collective permute.
)
if version == "v6e":
options.update(
# Change to 16GiB (16 * 1024**3 bytes). The default is 4GB, which is too small for
# larger models and causes the step time to double. Increase this further if you
# see "Allocator failed to allocate". A feature to allocate this dynamically may
# come later: b/380514965
megascale_grpc_premap_memory_bytes=17179869184,
# Improved performance for v6e.
xla_tpu_scoped_vmem_limit_kib=98304,
xla_tpu_enable_async_collective_fusion=True,
xla_tpu_enable_async_collective_fusion_fuse_all_gather=True,
xla_tpu_enable_async_collective_fusion_multiple_steps=True,
xla_tpu_overlap_compute_collective_tc=True,
xla_enable_async_all_gather=True,
# Host offloading flags
xla_tpu_enable_all_experimental_scheduler_features=True,
# Flag to enable memory tracking scheduling. The default AUTO only enables
# it in some situations. Not needed if
# xla_tpu_enable_all_experimental_scheduler_features is set to true already.
xla_tpu_enable_scheduler_memory_pressure_tracking=True,
# Flag controlling the maximum number of overlapping host offloadings.
xla_tpu_host_transfer_overlap_limit=24,
# Flag to enable the aggressive removal of opt-barriers.
xla_tpu_aggressive_opt_barrier_removal=True,
# Flag to enable more aggressive scheduling for async ops, such as pushing
# the async start to the beginning of the loop body.
xla_lhs_prioritize_async_depth_over_stall=True,
# Flag to enable pipelining of cross-DCN all-gathers.
xla_tpu_enable_ag_backward_pipelining=True,
xla_should_allow_loop_variant_parameter_in_chain=True,
xla_should_add_loop_invariant_op_in_chain=True,
# Flag controlling the maximum number of overlapping cross-DCN send/recv.
xla_max_concurrent_host_send_recv=100,
# Flag controlling the HBM memory limit as a percentage of the total HBM size.
# Default value is 95. Can be tuned up or down to give the scheduler more or less
# memory. Under memory pressure the scheduler favors lower memory usage over
# hiding latency by overlapping more computations and communications.
xla_tpu_scheduler_percent_shared_memory_limit=90,
# Flag controlling how many times the scheduler is rerun if the scheduled peak
# memory usage exceeds the initial memory limit; each rerun lowers the memory
# limit to 90% of the previous one. Default value is 1. Sometimes when the
# scheduler thinks it runs out of memory, that may not actually happen, due to
# other factors controlled by other compiler passes or because the initial memory
# limit was already set too low. Cutting the memory limit to 90% each time,
# though, may make the scheduler weight memory usage too heavily over latency.
xla_latency_hiding_scheduler_rerun=2,
xla_tpu_use_enhanced_launch_barrier=True,
# Sparsecore offloading for all reduce.
# Uncomment below flags to enable it.
# xla_sc_disable_megacore_partitioning=True,
# xla_tpu_use_tc_device_shape_on_sc=True,
# tpu_use_continuations=True,
# xla_jf_crs_combiner_threshold_count=10,
# xla_sc_enable_instruction_fusion="false",
# xla_sc_disjoint_spmem="false",
# xla_tpu_enable_sparse_core_collective_offload_all_reduce=True,
)
# This flag can be removed after upgrading to Jax 0.4.38.
# Uncomment for sparsecore offloading.
# options["2a886c8_chip_config_name"] = "megachip_tccontrol"
if num_slices > 1:
# Support multiple TPU slices connected over a data center network.
options.update(
@@ -59,12 +121,16 @@ def default_xla_options(

# Validate options. Will never fail if this function is implemented correctly.
for k, v in options.items():
assert v in [True, False, "true", "false"], (k, v)
try:
int(v)
continue
except ValueError:
assert v in [True, False, "true", "false", "megachip_tccontrol"], (k, v)

return options


def xla_flags_from_options(xla_options: dict[str, Union[str, bool]]) -> str:
def xla_flags_from_options(xla_options: dict[str, Union[str, bool, int]]) -> str:
"""Convert an XLA options dict suitable for
`jitted_fn.lower(...).compile(compiler_options=xla_options)`
to XLA flags suitable for the `XLA_FLAGS` environment variable.
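The docstring above notes that the returned options can be passed to `jitted_fn.lower(...).compile(compiler_options=...)`. Below is a minimal usage sketch (not part of the commit); the instance type string "tpu-v6e-256" is an assumption, and the accepted format is whatever infer_tpu_type and infer_tpu_version expect.

import jax
import jax.numpy as jnp

from axlearn.common.compiler_options import default_xla_options

# Build the default options for a single v6e slice.
options = default_xla_options(instance_type="tpu-v6e-256", num_slices=1, backend="tpu")

@jax.jit
def double(x):
    return x * 2

# Per the docstring, pass the options dict when compiling the lowered computation.
compiled = double.lower(jnp.ones((8, 128))).compile(compiler_options=options)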
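The same dict can also be applied globally through the XLA_FLAGS environment variable via xla_flags_from_options, per its docstring. A sketch under the same instance-type assumption, setting the variable before JAX/XLA initializes:

import os

from axlearn.common.compiler_options import default_xla_options, xla_flags_from_options

# Two v6e slices connected over a data center network; the multi-slice
# options are added automatically when num_slices > 1.
options = default_xla_options(instance_type="tpu-v6e-256", num_slices=2, backend="tpu")
os.environ["XLA_FLAGS"] = xla_flags_from_options(options)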
