Input batch sharding strategy BATCH #884
base: main
Changes from all commits
@@ -21,7 +21,6 @@
     host_to_global_device_array,
 )
 
-
 def is_supported(
     platform: str,
     mesh_shape: tuple[int, int],
@@ -37,16 +36,15 @@ def is_supported(
         )
     )
 
-
 class HostArrayTest(TestCase):
     @parameterized.parameters(
         filter(
             lambda params: is_supported(*params),
             itertools.product(
                 ("cpu", "tpu"),  # platform,
-                ((1, 1), (4, 1), (2, 2), (8, 1), (4, 2)),  # mesh_shape
+                ((1, 1), (4, 1), (2, 2), (8, 1), (4, 2), (16, 4)),  # mesh_shape
                 (1, 16),  # global_batch_size
-                (DataPartitionType.FULL, DataPartitionType.REPLICATED),  # data_partition
+                (DataPartitionType.FULL, DataPartitionType.REPLICATED, DataPartitionType.BATCH),  # data_partition
             ),
         )
     )

Review comment: Are you using a different pylint style? Those lines should not be removed.
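For context on the test change above: `@parameterized.parameters(filter(...))` enumerates the Cartesian product of platform, mesh shape, batch size, and partition type, then drops unsupported combinations via `is_supported`. The sketch below is illustrative only (not part of the PR; plain strings stand in for the `DataPartitionType` members) and shows how the grid grows once the `(16, 4)` mesh and `BATCH` partition are added.

```python
# Illustrative only: size of the parameter grid before is_supported() filtering.
import itertools

platforms = ("cpu", "tpu")
mesh_shapes = ((1, 1), (4, 1), (2, 2), (8, 1), (4, 2), (16, 4))
global_batch_sizes = (1, 16)
data_partitions = ("full", "replicated", "batch")  # stand-ins for DataPartitionType

grid = list(itertools.product(platforms, mesh_shapes, global_batch_sizes, data_partitions))
print(len(grid))  # 2 * 6 * 2 * 3 = 72 combinations, up from 2 * 5 * 2 * 2 = 40
```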
@@ -591,14 +591,17 @@ class DataPartitionType(Enum):
     FULL = "full"
     # Data are fully replicated across all devices.
     REPLICATED = "replicated"
+    # Data are partitioned across the batch axis only.
+    BATCH = "batch"
 
 
-def data_partition_type_to_spec(partition: DataPartitionType) -> PartitionSpec:
+def data_partition_type_to_spec(partition: DataPartitionType, *, batch_axis_names: Union[str, Sequence[str]] = ("data", "fsdp")) -> PartitionSpec:
     """Returns a PartitionSpec for the given partition type."""
     if partition == DataPartitionType.FULL:
         return input_partition_spec()
     elif partition == DataPartitionType.REPLICATED:
         return None
+    elif partition == DataPartitionType.BATCH:
+        return PartitionSpec(batch_axis_names)
     else:
         raise NotImplementedError(f"Unsupported partition: {partition}")

Review comment: Instead of directly assigning PartitionSpec, maybe extend …
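As a rough illustration of the new `BATCH` branch, here is a minimal sketch (not part of the PR) of the spec it returns under the default `batch_axis_names=("data", "fsdp")`, using only public JAX APIs:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a ("data", "fsdp") mesh out of whatever local devices are available.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, ("data", "fsdp"))

# What data_partition_type_to_spec(DataPartitionType.BATCH) returns with defaults:
# dimension 0 (the batch) is sharded jointly over "data" and "fsdp";
# all remaining dimensions stay replicated.
batch_spec = PartitionSpec(("data", "fsdp"))

x = np.arange(devices.size * 3).reshape(devices.size, 3)
y = jax.device_put(x, NamedSharding(mesh, batch_spec))
print(y.sharding)  # NamedSharding with spec PartitionSpec(('data', 'fsdp'),)
```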
@@ -607,6 +610,7 @@ def host_to_global_device_array(
     host_arrays: Nested[Union[np.ndarray, Tensor]],
     *,
     partition: DataPartitionType = DataPartitionType.FULL,
+    batch_axis_names: Union[str, Sequence[str]] = ("data", "fsdp"),
 ) -> NestedTensor:
     """Converts the given host device arrays to global device arrays.

Review thread on the new `batch_axis_names` parameter:
- I think @markblee plans to remove the …
- Thanks, that sounds promising. Hello @markblee, let me know if this PR is needed until you make your changes, or if you have your design in mind I can reshape the PR to make it compatible with your design.
- I think …
- Yes, support for … This brings the API closer to JAX native, but may require more changes internally.
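A hedged usage sketch of the extended API follows. It assumes this PR's signature for `host_to_global_device_array` (not verified against the final merged API) and a single-process run; the mesh axis names are only illustrative.

```python
import jax
import numpy as np
from jax.sharding import Mesh
from axlearn.common.utils import DataPartitionType, host_to_global_device_array

devices = np.array(jax.devices()).reshape(-1, 1)
with Mesh(devices, ("data", "fsdp")):
    local_batch = {"input_ids": np.arange(devices.size)[:, None]}  # per-host slice
    global_batch = host_to_global_device_array(
        local_batch,
        partition=DataPartitionType.BATCH,
        batch_axis_names=("data", "fsdp"),
    )
    # With BATCH partitioning the leading dim scales with the process count.
    print(global_batch["input_ids"].shape)  # (process_count * local_batch_size, 1)
```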
@@ -625,7 +629,7 @@ def host_to_global_device_array(
         NotImplementedError: if the given `partition` type is not supported.
     """
     mesh = thread_resources.env.physical_mesh
-    partition_spec = data_partition_type_to_spec(partition)
+    partition_spec = data_partition_type_to_spec(partition, batch_axis_names=batch_axis_names)
     partition_specs = complete_partition_spec_tree(
         jax.tree_util.tree_structure(host_arrays), partition_spec
     )
@@ -636,6 +640,8 @@ def make_gda(x, partition_spec):
             global_shape = (x.shape[0] * process_count, *x.shape[1:])
         elif partition == DataPartitionType.REPLICATED:
             global_shape = (x.shape[0], *x.shape[1:])
+        elif partition == DataPartitionType.BATCH:
+            global_shape = (x.shape[0] * process_count, *x.shape[1:])
         else:
             raise NotImplementedError(f"Unsupported partition: {partition}")
         return jax.make_array_from_process_local_data(
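The shape arithmetic for the new branch mirrors the `FULL` case: each process contributes its local slice, so the global batch dimension is the per-process dimension multiplied by the process count, before `jax.make_array_from_process_local_data` assembles the global array. A tiny standalone illustration (the numbers are hypothetical):

```python
# Hypothetical numbers, for illustration only.
local_shape = (4, 128)   # per-process batch: 4 examples, 128 features
process_count = 8        # number of hosts feeding data

# BATCH (and FULL): the global batch dim is the per-process dim times process_count.
global_shape_batch = (local_shape[0] * process_count, *local_shape[1:])
# REPLICATED: every host already holds the full batch, so the shape is unchanged.
global_shape_replicated = local_shape

print(global_shape_batch)       # (32, 128)
print(global_shape_replicated)  # (4, 128)
```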
@@ -42,6 +42,7 @@
 )
 from axlearn.common.trainer import SpmdTrainer
 from axlearn.common.utils import (
+    DataPartitionType,
     PHYSICAL_TO_LOGICAL_DISPATCH_KEY,
     HybridMeshShape,
     MeshShape,
@@ -1701,6 +1702,31 @@ def test_length(self):
 class HostToGlobalArrayTest(TestCase):
     """Tests host_to_global_device_array."""
 
+    @pytest.mark.neuron
+    def test_partition_batch(self):
+        """Test a case where each process produces a slice."""
+        device_count = jax.device_count()
+        process_count = jax.process_count()
+        print(f"{device_count=}, {process_count=}")
+        assert device_count > 1
+
+        global_shape = (device_count // 2, 1)
+        assert global_shape[0] % process_count == 0
+        per_feed_size = global_shape[0] // process_count
+        feed_index = jax.process_index()
+
+        with jax.sharding.Mesh(np.array(jax.devices()).reshape(device_count // 2, 2), ("x", "y")):
+            start = feed_index * per_feed_size
+            local_x = jnp.arange(start, start + per_feed_size)[:, None]
+
+            # Construct global array.
+            global_x = host_to_global_device_array(local_x, partition=DataPartitionType.BATCH, batch_axis_names="x")
+
+            # Compare against expected.
+            expected = jnp.arange(global_shape[0])[:, None]
+            self.assertEqual(jnp.mean(expected), jnp.mean(global_x))
+            self.assertNestedEqual(expected, replicate_to_local_data(global_x))
+
     @pytest.mark.tpu
     def test_partition_full(self):
         """Test a case where each process produces a slice."""

Review comment (on test_partition_batch): Are we able to pass the …
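To make the intent of `test_partition_batch` concrete without a multi-device mesh, here is a plain-numpy sketch (not part of the PR; the counts are hypothetical) of the data flow the test checks: each of `process_count` feeds contributes `per_feed_size` consecutive rows, and concatenating them along the batch axis reproduces `arange(global_rows)`.

```python
import numpy as np

process_count = 4                       # hypothetical number of feeds
global_rows = 8                         # e.g. device_count // 2 in the test
per_feed_size = global_rows // process_count

# Each feed builds its own contiguous slice, like local_x in the test.
feeds = [
    np.arange(i * per_feed_size, (i + 1) * per_feed_size)[:, None]
    for i in range(process_count)
]

# host_to_global_device_array(..., partition=BATCH) stitches these along dim 0.
expected = np.arange(global_rows)[:, None]
assert (np.concatenate(feeds, axis=0) == expected).all()
```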
@@ -39,7 +39,7 @@
     MeshShapeModifier,
     RematSpecModifier,
 )
-from axlearn.common.utils import extended_checkpoint_policies
+from axlearn.common.utils import DataPartitionType, extended_checkpoint_policies
 from axlearn.experiments.text.gpt.common import (
     STEP_DTYPE,
     SourceBuilder,
@@ -423,6 +423,7 @@ def get_trainer_kwargs(
         raise NotImplementedError(f"Unknown model size {model_size}.")
     model_kwargs = trainer_kwargs.pop("model_kwargs")
     model_kwargs.setdefault("vocab_size", vocab_size)
+    trainer_kwargs["input_partition_type"] = None if backend != "neuron" else DataPartitionType.BATCH
     trainer_kwargs["model_cfg"] = model_config(**model_kwargs)
     trainer_kwargs["learner_cfg"] = adamw_decoupled_learner_config(
         max_step=trainer_kwargs["max_step"],

Review comment: Please go with a config modifier instead of hardcoding this in the code; otherwise it becomes hard for people to debug.
Review comment: Did you run through pylint? This line seems quite long.
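On the line-length comment: one way to keep the assignment under a typical pylint limit without changing behavior is to wrap the conditional. This is a sketch only, not the PR's code; `backend` and `trainer_kwargs` are stubbed so the snippet runs standalone.

```python
from axlearn.common.utils import DataPartitionType

backend = "neuron"   # stub: in get_trainer_kwargs this comes from the caller
trainer_kwargs = {}  # stub: normally built earlier in the function

trainer_kwargs["input_partition_type"] = (
    DataPartitionType.BATCH if backend == "neuron" else None
)
print(trainer_kwargs["input_partition_type"])  # DataPartitionType.BATCH
```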