
Add default compiler options for v6e #887

Merged
merged 4 commits into apple:main on Dec 12, 2024

Conversation

samos123 (Contributor)

No description provided.

@samos123 samos123 force-pushed the v6e-compiler-options-clean branch from 564455c to 7ca21ff Compare December 11, 2024 22:09
@samos123 samos123 force-pushed the v6e-compiler-options-clean branch from 7ca21ff to 655c539 Compare December 11, 2024 23:23
@Ethanlm Ethanlm left a comment


Thanks

@samos123 samos123 requested a review from Ethanlm December 12, 2024 01:34
@Ethanlm Ethanlm added this pull request to the merge queue Dec 12, 2024
Merged via the queue into apple:main with commit 420fb66 Dec 12, 2024
6 checks passed
# Improved performance for v6e.
xla_tpu_scoped_vmem_limit_kib=98304,
xla_tpu_enable_async_collective_fusion=True,
xla_tpu_enable_async_collective_fusion_fuse_all_gather=True,
@Ethanlm (Contributor) commented Dec 12, 2024:
These need to be "true" instead of True, otherwise they will be converted to 1 by xla_flags_from_options and then fail:

I1212 21:41:11.645763 132689403348992 launch.py:112] LIBTPU_INIT_ARGS='--xla_tpu_spmd_rng_bit_generator_unsafe=1 --xla_tpu_enable_latency_hiding_scheduler=true --xla_tpu_perform_spmd_cse_prevention=false --megascale_grpc_premap_memory_bytes=17179869184 --xla_tpu_scoped_vmem_limit_kib=98304 --xla_tpu_enable_async_collective_fusion=1 --xla_tpu_enable_async_collective_fusion_fuse_all_gather=1 --xla_tpu_enable_async_collective_fusion_multiple_steps=1 --xla_tpu_overlap_compute_collective_tc=1 --xla_enable_async_all_gather=1 --xla_tpu_enable_all_experimental_scheduler_features=1 --xla_tpu_enable_scheduler_memory_pressure_tracking=1 --xla_tpu_host_transfer_overlap_limit=24 --xla_tpu_aggressive_opt_barrier_removal=1 --xla_lhs_prioritize_async_depth_over_stall=1 --xla_tpu_enable_ag_backward_pipelining=1 --xla_should_allow_loop_variant_parameter_in_chain=1 --xla_should_add_loop_invariant_op_in_chain=1 --xla_max_concurrent_host_send_recv=100 --xla_tpu_scheduler_percent_shared_memory_limit=90 --xla_latency_hiding_scheduler_rerun=2 --xla_tpu_use_enhanced_launch_barrier=1'
2024-12-12 21:41:14.669321: I external/tsl/tsl/platform/default/grpc_credentials.cc:30] gRPC insecure client credentials are used.
I1212 21:41:14.670917 132689403348992 distributed.py:119] Connecting to JAX distributed service on ethanli-fuji-70b-v2-test1-job-0-0.ethanli-fuji-70b-v2-test1:8476
2024-12-12 21:41:14.999583: I external/xla/xla/pjrt/distributed/client.cc:135] Connected to distributed JAX controller
ERROR: Illegal value '1' specified for flag 'xla_tpu_enable_async_collective_fusion_fuse_all_gather'; expected one of true/enabled, false/disabled or auto
ERROR: Illegal value '1' specified for flag 'xla_enable_async_all_gather'; expected one of true/enabled, false/disabled or auto
ERROR: Illegal value '1' specified for flag 'xla_tpu_enable_scheduler_memory_pressure_tracking'; expected one of true/enabled, false/disabled or auto
ERROR: Illegal value '1' specified for flag 'xla_tpu_aggressive_opt_barrier_removal'; expected one of true/enabled, false/disabled or auto
ERROR: Illegal value '1' specified for flag 'xla_lhs_prioritize_async_depth_over_stall'; expected one of true/enabled, false/disabled or auto
ERROR: Illegal value '1' specified for flag 'xla_should_allow_loop_variant_parameter_in_chain'; expected one of true/enabled, false/disabled or auto
ERROR: Illegal value '1' specified for flag 'xla_should_add_loop_invariant_op_in_chain'; expected one of true/enabled, false/disabled or auto

samos123 (Author) replied:
That's what I had originally, but pytype doesn't like that and will fail. So I see two options:

  1. Change the behavior of axlearn so that it converts a True boolean to "true" instead of 1.
  2. Ignore the pytype check.
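A minimal sketch of what option 1 could look like. This assumes the options arrive as a dict; the actual xla_flags_from_options in axlearn may have a different signature and behavior, so this is purely illustrative:

```python
# Hypothetical sketch of option 1: emit "true"/"false" for booleans instead
# of 1/0. Not the actual axlearn implementation.
def xla_flags_from_options(options: dict) -> str:
    """Converts an options dict into a LIBTPU_INIT_ARGS-style flag string."""
    flags = []
    for name, value in options.items():
        if isinstance(value, bool):
            # XLA boolean flags expect true/false (or enabled/disabled/auto),
            # not 1/0 -- converting True to 1 triggers the errors in the log above.
            value = "true" if value else "false"
        flags.append(f"--{name}={value}")
    return " ".join(flags)

print(xla_flags_from_options({
    "xla_tpu_enable_async_collective_fusion": True,
    "xla_tpu_scoped_vmem_limit_kib": 98304,
}))
```

This keeps the Python-side config typed as real booleans (so pytype is satisfied) while producing the string form that the XLA flag parser accepts.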

xla_tpu_use_enhanced_launch_barrier=True,
# Sparsecore offloading for all reduce.
# Uncomment below flags to enable it.
# xla_sc_disable_megacore_partitioning=True,
Contributor:
@Ethanlm do you recall why you commented them all out?

samos123 (Author) replied:
We weren't sure whether this would perform better for everyone, and it's also a newer feature, so Ethan recommended not enabling it by default. That said, for both Llama 2 70B and 405B we saw much better performance with it enabled.
