
Merge branch 'main' into mamba2
berlino authored Dec 4, 2024
2 parents fd6b372 + 2056857 commit 1469a35
Showing 132 changed files with 5,769 additions and 3,349 deletions.
32 changes: 32 additions & 0 deletions .github/workflows/build.yml
@@ -0,0 +1,32 @@
name: build-and-test
on: [pull_request, push]

jobs:
  build-and-test-job:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        test-group: [a, b, c, d, e]
    steps:
      - uses: actions/checkout@v4
      - uses: docker/setup-buildx-action@v3
      - name: Gather test files
        run: find axlearn -name '*_test.py' > pytest_files.txt
      - name: Split test files into groups
        # GNU split lets us do "-n r/5" to round robin into 5 files without breaking lines
        # BSD split requires knowing the number of lines and uses "-l XX"
        run: split -n r/5 -a 1 pytest_files.txt split_pytest_files
      - name: Select a test group
        run: tr '\n' ' ' < split_pytest_files${{ matrix.test-group }} > test_files_oneline
      - name: Read test inputs
        id: test-selector
        run: echo "PYTEST_FILES='$(cat test_files_oneline)'" >> "$GITHUB_OUTPUT"
      - name: Run tests
        uses: docker/build-push-action@v6
        with:
          push: false
          target: ci
          context: .
          build-args: |
            SKIP_PRECOMMIT=--skip-pre-commit
            PYTEST_FILES=${{ steps.test-selector.outputs.PYTEST_FILES }}
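For intuition, here is a rough Python sketch of the round-robin grouping that `split -n r/5 -a 1` performs above (illustrative only; the file list and group contents are hypothetical, and the workflow relies on GNU split rather than this script):

```
# Sketch of round-robin line distribution, mirroring
# "split -n r/5 -a 1 pytest_files.txt split_pytest_files".
# The resulting groups correspond to suffixes a..e consumed by matrix.test-group.
test_files = [f"axlearn/common/module{i}_test.py" for i in range(12)]  # hypothetical file list

groups: dict[str, list[str]] = {name: [] for name in "abcde"}
for idx, path in enumerate(test_files):
    groups["abcde"[idx % 5]].append(path)  # line i goes to group i mod 5

for name, files in groups.items():
    # Analogous to the `tr '\n' ' '` step that flattens a group into one line of PYTEST_FILES.
    print(f"split_pytest_files{name}: {' '.join(files)}")
```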
20 changes: 20 additions & 0 deletions .github/workflows/pre-commit.yml
@@ -0,0 +1,20 @@
name: pre-commit

on: pull_request

jobs:
  pre-commit:
    runs-on: ubuntu-latest
    # resource_class: large
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: '3.10'
          cache: 'pip'
      - run: pip install --upgrade pip
      # TODO(markblee): Remove gcp,vertexai_tensorboard from CI. (needed by pytype)
      - run: pip install '.[core,dev,grain,gcp,vertexai_tensorboard]'
      # pylint uses approx 12GB of memory during this run, look into split to decrease?
      - run: pre-commit run --all-files
      - run: pytype -j auto .
2 changes: 1 addition & 1 deletion Dockerfile
@@ -39,7 +39,7 @@ RUN pip install --upgrade pip
################################################################################

# Leverage multi-stage build for unit tests.
FROM base as ci
FROM base AS ci

# TODO(markblee): Remove gcp,vertexai_tensorboard from CI.
RUN pip install .[core,dev,grain,gcp,vertexai_tensorboard]
6 changes: 3 additions & 3 deletions axlearn/audio/decoder_asr.py
@@ -1139,7 +1139,7 @@ def forward(self, input_batch: Nested[Tensor]) -> tuple[Tensor, Nested[Tensor]]:
            paddings=input_batch["paddings"]
        )
        predict_outputs = self.decoder(
            input_ids=input_batch["target"]["input_ids"],
            input_batch=input_batch["target"],
            cross_attention_data=input_batch["inputs"],
            cross_attention_logit_biases=cross_attention_logit_biases,
        )
@@ -1241,7 +1241,7 @@ def beam_search_decode(

        with child_context("beam_search_decode", module=self.decoder):
            beam_search_outputs: decoding.BeamSearchOutputs = self.decoder.beam_search_decode(
                prefix=input_batch["prefix"],
                input_batch=input_batch,
                max_sequence_length=max_decode_len,
                num_decodes=num_decodes,
                cross_attention_data=input_batch["inputs"],
@@ -1291,7 +1291,7 @@ def sample_decode(

        with child_context("sample_decode", module=self.decoder):
            sample_decode_outputs: decoding.SampleOutputs = self.decoder.sample_decode(
                prefix=input_batch["prefix"],
                input_batch=input_batch,
                max_sequence_length=max_decode_len,
                num_decodes=num_decodes,
                cross_attention_data=input_batch["inputs"],
8 changes: 4 additions & 4 deletions axlearn/audio/model_asr.py
@@ -9,7 +9,7 @@
from axlearn.common.base_encoder_decoder import BaseEncoderDecoderModel
from axlearn.common.config import REQUIRED, Required, config_class
from axlearn.common.module import Module
from axlearn.common.utils import Nested, Tensor
from axlearn.common.utils import Nested, Tensor, validate_contains_paths


class ASRModel(BaseEncoderDecoderModel):
@@ -51,7 +51,7 @@ def predict(self, input_batch: Nested[Tensor]) -> Nested[Tensor]:
        Returns:
            A dict containing logits. The shape of logits depend on the decoder.
        """
        self._validate_input_batch(input_batch, ["source", "target", "target_labels"])
        validate_contains_paths(input_batch, ["source", "target", "target_labels"])
        # Encoder hidden states: [batch_size, source_len, dim].
        encoder_output = self.encoder(**input_batch["source"])
        logits = self.decoder.predict(
@@ -82,7 +82,7 @@ def forward(
            aux_outputs: A dict containing auxiliary outputs if `return_aux=True`, otherwise an
                empty dict.
        """
        self._validate_input_batch(input_batch, ["source", "target", "target_labels"])
        validate_contains_paths(input_batch, ["source", "target", "target_labels"])
        # Encoder hidden states: [batch_size, source_len, dim].
        encoder_output = self.encoder(**input_batch["source"])
        loss, aux_outputs = self.decoder(
@@ -115,7 +115,7 @@ def beam_search_decode(
        Returns:
            Beam search decode outputs.
        """
        self._validate_input_batch(input_batch, ["source/inputs", "source/paddings"])
        validate_contains_paths(input_batch, ["source/inputs", "source/paddings"])
        encoder_output = self.encoder(**input_batch["source"])
        return self.decoder.beam_search_decode(
            input_batch=dict(
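The `self._validate_input_batch` calls above are replaced by the shared `validate_contains_paths` helper imported from `axlearn.common.utils`. As a rough sketch of the contract implied by these call sites (slash-separated key paths into a nested input batch), not the actual axlearn implementation:

```
# Hypothetical sketch of the helper's contract, inferred from the call sites above.
# The real validate_contains_paths lives in axlearn.common.utils and may differ in details.
def validate_contains_paths(input_batch: dict, paths: list[str]) -> None:
    """Raises ValueError if any slash-separated path (e.g. "source/inputs") is missing."""
    for path in paths:
        current = input_batch
        for key in path.split("/"):
            if not isinstance(current, dict) or key not in current:
                raise ValueError(f"Input batch is expected to contain path {path!r}.")
            current = current[key]
```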
3 changes: 2 additions & 1 deletion axlearn/audio/subsamplers.py
@@ -7,7 +7,8 @@

from axlearn.common.base_layer import BaseLayer
from axlearn.common.config import REQUIRED, Required, config_class
from axlearn.common.layers import BaseNormalizationLayer, Conv2DWith1DPadding, get_activation_fn
from axlearn.common.convolution import Conv2DWith1DPadding
from axlearn.common.layers import BaseNormalizationLayer, get_activation_fn
from axlearn.common.module import Module
from axlearn.common.utils import Tensor

12 changes: 12 additions & 0 deletions axlearn/cloud/gcp/job.py
@@ -59,6 +59,10 @@
# A label added to the jobset to indicate job version.
BASTION_JOB_VERSION_LABEL = "bastion-job-version"

# The metadata.google.internal IP.
# https://cloud.google.com/compute/docs/troubleshooting/troubleshoot-metadata-server#failed-request
_METADATA_GOOGLE_INTERNAL_IP = "169.254.169.254"


class GCPJob(Job):
"""Base GCP Job definition."""
@@ -747,11 +751,19 @@ def _build_pod(self) -> Nested[Any]:
            }
        )

        # Hardcode metadata.google.internal ip address to avoid transient DNS resolution issue.
        metadata_host_alias = dict(
            ip=_METADATA_GOOGLE_INTERNAL_IP,
            hostnames=["metadata", "metadata.google.internal"],
        )

        spec = dict(
            # NOTE: Don't set hostNetwork or dnsPolicy for compat with Workload Identity.
            terminationGracePeriodSeconds=60,
            # Fail if any pod fails, and allow retries to happen at JobSet level.
            restartPolicy="Never",
            # https://kubernetes.io/docs/tasks/network/customize-hosts-file-for-pods/#adding-additional-entries-with-hostaliases
            hostAliases=[metadata_host_alias],
            nodeSelector={
                "cloud.google.com/gke-tpu-accelerator": system.gke_accelerator,
                "cloud.google.com/gke-tpu-topology": system.topology,
12 changes: 12 additions & 0 deletions axlearn/cloud/gcp/job_test.py
@@ -40,6 +40,7 @@
from axlearn.cloud.gcp.config import gcp_settings
from axlearn.cloud.gcp.job import (
    _MEMORY_REQUEST_PERCENTAGE,
    _METADATA_GOOGLE_INTERNAL_IP,
    BASTION_JOB_VERSION_LABEL,
    CPUJob,
    GCSFuseMount,
@@ -393,6 +394,17 @@ def test_build_pod(
        node_selector = pod_spec["nodeSelector"]
        annotations = pod["metadata"]["annotations"]
        labels = pod["metadata"]["labels"]
        host_alias = pod["spec"]["hostAliases"]

        self.assertEqual(1, len(host_alias))
        self.assertEqual(
            dict(
                ip=_METADATA_GOOGLE_INTERNAL_IP,
                hostnames=["metadata", "metadata.google.internal"],
            ),
            host_alias[0],
        )

        # The reservation should be used only if scheduled as tier 0.
        if expect_reserved:
            self.assertEqual(
17 changes: 9 additions & 8 deletions axlearn/cloud/gcp/tpu_health_check.py
@@ -15,11 +15,13 @@
shortest timeout. Pairwise health check should have the longest timeout since different slices may
bring up their container at different times.
The main API is the `health_check` function, which is commonly enabled via context manager:
with health_check(spec, output_dir=...):
The main API is the `setup` function, which is commonly enabled via context manager:
```
with setup(spec):
    # Initialize jax distributed.
```
"""

import os
import signal
import subprocess
@@ -31,8 +33,7 @@
from typing import Literal, Optional, Union

import tensorflow as tf
import tensorflow_io # pylint: disable=unused-import
from absl import logging
from absl import flags, logging

from axlearn.cloud.gcp import tpu_health_check_main

@@ -127,11 +128,11 @@ def _run_health_check_program(


@contextmanager
def health_check(check_spec: str, *, output_dir: str):
    _pre_init_health_check(check_spec, output_dir=output_dir)
def setup(check_spec: str):
    _pre_init_health_check(check_spec, output_dir=flags.FLAGS.trainer_dir)
    yield
    # Skip global health check if there's an exception.
    global_health_check(check_spec, output_dir=output_dir)
    global_health_check(check_spec, output_dir=flags.FLAGS.trainer_dir)


def _pre_init_health_check(check_spec: str, *, output_dir: str):
11 changes: 7 additions & 4 deletions axlearn/common/adapter_torch_test.py
@@ -1,6 +1,7 @@
# Copyright © 2023 Apple Inc.

"""Tests PyTorch adapter layers."""

# pylint: disable=too-many-lines
import itertools
from collections import OrderedDict
@@ -889,9 +890,11 @@ def test_transformer_embeddings_forward(
            jax.random.PRNGKey(0),
            state=axlearn_layer_state,
            inputs=dict(
                inputs=jnp.asarray(input_ids),
                token_type_ids=axlearn_token_type_ids,
                positions=axlearn_positions,
                input_batch=dict(
                    inputs=jnp.asarray(input_ids),
                    token_type_ids=axlearn_token_type_ids,
                    positions=axlearn_positions,
                ),
            ),
            is_training=False,
            method="forward",
@@ -971,7 +974,7 @@ def test_decoder_inference(self):
            axlearn_layer,
            jax.random.PRNGKey(0),
            state=axlearn_layer_state,
            inputs=dict(input_ids=jnp.asarray(input_ids)),
            inputs=dict(input_batch=dict(input_ids=jnp.asarray(input_ids))),
            is_training=False,
            method="forward",
        )[0]
52 changes: 38 additions & 14 deletions axlearn/common/array_serialization.py
@@ -72,7 +72,9 @@ def _num_replicas_per_shard(arr: Tensor) -> dict[tuple[_SliceTuple, ...], int]:
    return dict(replica_count)


def _get_shard_infos(arr_inp: Tensor, *, max_data_shard_degree: int) -> list[_ShardInfo]:
def _get_shard_infos(
    arr_inp: Tensor, *, max_data_shard_degree: int, shard_threshold_bytes: int
) -> list[_ShardInfo]:
    """Returns a list of _ShardInfo for addressable shards that need to be saved.

    If replica count for the shards are greater than 0, all replicas will save slices of the
@@ -84,11 +86,21 @@ def _get_shard_infos(arr_inp: Tensor, *, max_data_shard_degree: int) -> list[_Sh
    for shard in arr_inp.addressable_shards:
        replica_count = replica_count_map[_slices_to_tuple(shard.index)]
        assert replica_count > 0
        shard_degree = (
            min(replica_count, max_data_shard_degree)
            if max_data_shard_degree > 0
            else replica_count
        )
        should_skip = (
            shard_degree == 1
            or shard.data.nbytes < shard_threshold_bytes
            or shard.replica_id >= shard_degree
        )
        for axis, size in enumerate(shard.data.shape):
            # Find the first dim divisible by partial replication size.
            if max_data_shard_degree == 1 or replica_count == 1 or size % replica_count != 0:
            if should_skip or size % shard_degree != 0:
                continue
            part_size = size // replica_count
            part_size = size // shard_degree
            slice_obj = shard.index[axis]
            assert slice_obj.step is None
            start_offset = shard.replica_id * part_size
@@ -103,7 +115,7 @@ def _get_shard_infos(arr_inp: Tensor, *, max_data_shard_degree: int) -> list[_Sh
                    + (slice(slice_start + start_offset, slice_start + end_offset),)
                    + shard.index[axis + 1 :],
                    (start_offset, end_offset, axis),
                    replica_count,
                    shard_degree,
                )
            )
            break
@@ -181,7 +193,8 @@ async def _async_serialize(
    d2h_future: futures.Future,
    *,
    limiter: Optional[serialization._LimitInFlightBytes] = None,
    max_data_shard_degree: Optional[int] = None,
    max_data_shard_degree: int,
    shard_threshold_bytes: int,
):
    """Similar to `serialization.async_serialize`, but limiting peak host memory usage and sharding
    along data-parallel axis.
@@ -195,7 +208,11 @@
    Reference:
    https://github.com/google/jax/blob/595a620804e810335a870e93975a78504b2e95e5/jax/experimental/array_serialization/serialization.py#L188
    """
    shard_infos = _get_shard_infos(arr_inp, max_data_shard_degree=max_data_shard_degree)
    shard_infos = _get_shard_infos(
        arr_inp,
        max_data_shard_degree=max_data_shard_degree,
        shard_threshold_bytes=shard_threshold_bytes,
    )
    if not shard_infos:
        d2h_future.set_result(shard_infos)
        return
@@ -261,7 +278,8 @@ async def _run_serializer(
    d2h_futures: list[futures.Future],
    *,
    max_concurrent_bytes: Optional[int] = None,
    max_data_shard_degree: Optional[int] = None,
    max_data_shard_degree: int,
    shard_threshold_bytes: int,
):
    """Asynchronously serializes a list of tensors with _async_serialize."""
    # We add 1 because LimitInFlightBytes expects a limit strictly greater than any request.
@@ -274,7 +292,10 @@ async def _run_serializer(
    # pylint: enable=protected-access
    future_writer = jax.tree.map(
        functools.partial(
            _async_serialize, limiter=limiter, max_data_shard_degree=max_data_shard_degree
            _async_serialize,
            limiter=limiter,
            max_data_shard_degree=max_data_shard_degree,
            shard_threshold_bytes=shard_threshold_bytes,
        ),
        arrays,
        tensorstore_specs,
@@ -385,7 +406,9 @@ class BoundedDataShardedAsyncCheckpointManager(serialization.GlobalAsyncCheckpoi
        max_concurrent_gb: Max concurrent shards (in GB) to write.
        max_data_shard_degree: Max sharding degree of model weights along data-parallel axis.
            `None` and `1` means no sharding. `-1` means fully shard along data-parallel
            replicas. `>1` means custom sharding degree (currently not implemented).
            replicas. `>1` means custom sharding degree and should almost always be a power of 2.
        shard_threshold_bytes: Threshold for a array shard to be data-sharded. A value of None
            or <= 0 means always data-shard according to max_data_shard_degree.
        timeout_secs: Barrier timeout in seconds.
    """

@@ -395,6 +418,7 @@ def __init__(
        max_concurrent_gb: Optional[int] = None,
        timeout_secs: int = 300,
        max_data_shard_degree: Optional[int] = None,
        shard_threshold_bytes: Optional[int] = None,
    ):
        super().__init__(timeout_secs)
        self._logged_spec = False
@@ -406,11 +430,10 @@ def __init__(
            raise ValueError("max_concurrent_gb must be strictly positive.")
        self._max_concurrent_bytes = int(max_concurrent_gb * 10**9)

        self._max_data_shard_degree = max_data_shard_degree or 1
        if self._max_data_shard_degree not in (1, -1):
            raise NotImplementedError(
                "max_data_shard_degree is not implemented for values other than 1 and -1"
            )
        self._max_data_shard_degree = 1 if max_data_shard_degree is None else max_data_shard_degree
        if self._max_data_shard_degree == 0:
            raise NotImplementedError("max_data_shard_degree cannot be 0.")
        self._shard_threshold_bytes = shard_threshold_bytes or 0

    def serialize(
        self,
@@ -457,6 +480,7 @@ def serialize(
                    d2h_futures,
                    max_concurrent_bytes=max_concurrent_bytes,
                    max_data_shard_degree=self._max_data_shard_degree,
                    shard_threshold_bytes=self._shard_threshold_bytes,
                )
            )
        ]
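To make the new data-sharding decision concrete, here is a small standalone restatement of the logic introduced in `_get_shard_infos` (illustrative only: the real code operates on jax array shards, and the numbers below are made up):

```
# Condensed restatement of the shard_degree / should_skip logic from the diff above.
def resolve_shard_degree(replica_count: int, max_data_shard_degree: int) -> int:
    """max_data_shard_degree > 0 caps the degree; <= 0 (e.g. -1) means shard across all replicas."""
    return min(replica_count, max_data_shard_degree) if max_data_shard_degree > 0 else replica_count


def should_skip_data_sharding(
    shard_degree: int, shard_nbytes: int, replica_id: int, shard_threshold_bytes: int
) -> bool:
    """Skip when there is nothing to split, the shard is too small, or this replica holds no slice."""
    return (
        shard_degree == 1
        or shard_nbytes < shard_threshold_bytes
        or replica_id >= shard_degree
    )


# Example: 8 data-parallel replicas, degree capped at 4, a 1 MiB shard, 512 KiB threshold.
degree = resolve_shard_degree(replica_count=8, max_data_shard_degree=4)  # -> 4
print(should_skip_data_sharding(degree, 1 << 20, replica_id=0, shard_threshold_bytes=512 * 1024))  # False
print(should_skip_data_sharding(degree, 1 << 20, replica_id=5, shard_threshold_bytes=512 * 1024))  # True
```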