Clean up axis hooks in nnx.Variable #4189

Merged (1 commit, Sep 11, 2024)
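Summary of the changes below: the axis hooks on `nnx.Variable` and `nnx.VariableState` now take the axis index first and an optional axis name second (`add_axis(index, name)` rather than `add_axis(name, index)`); the call sites in `flax/nnx/spmd.py` are updated to match; the repetitive hook-normalization branches in `Variable.__init__` are collapsed; the `hasattr` guards in `VariableState.add_axis`/`remove_axis` are dropped; and a new test exercises the hooks under `nnx.vmap` and `nnx.scan`.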
flax/nnx/spmd.py (4 changes: 2 additions & 2 deletions)

```diff
@@ -44,7 +44,7 @@ def _add_axis(x: tp.Any):
         sharding.insert(index, axis_name)
         x.sharding = tuple(sharding)  # type: ignore
 
-      x.add_axis(axis_name, index)
+      x.add_axis(index, axis_name)
     return x
 
   return jax.tree.map(
@@ -61,7 +61,7 @@ def _remove_axis(x: tp.Any):
         sharding = list(x.sharding)
         assert sharding.pop(index) == axis_name
         x.sharding = tuple(sharding)
-      x.remove_axis(axis_name, index)
+      x.remove_axis(index, axis_name)
     return x
 
   return jax.tree.map(
```
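Both hunks swap the call sites to the new `(index, axis_name)` order defined in `variables.py` below. The surrounding context also shows the sharding bookkeeping these helpers perform before invoking the hook; here is a self-contained sketch of that logic, where the function names are illustrative and not flax internals:

```python
# Illustrative sketch of the sharding bookkeeping around the hook calls in
# _add_axis/_remove_axis; `insert_sharding_axis` and `pop_sharding_axis` are
# hypothetical names, not part of flax.
def insert_sharding_axis(sharding: tuple, index: int, axis_name: str | None) -> tuple:
  s = list(sharding)
  s.insert(index, axis_name)  # record the new axis in the sharding annotation
  return tuple(s)

def pop_sharding_axis(sharding: tuple, index: int, axis_name: str | None) -> tuple:
  s = list(sharding)
  assert s.pop(index) == axis_name  # the removed entry must match the axis name
  return tuple(s)

assert insert_sharding_axis(('din', 'dout'), 0, 'layers') == ('layers', 'din', 'dout')
assert pop_sharding_axis(('layers', 'din', 'dout'), 0, 'layers') == ('din', 'dout')
```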
flax/nnx/variables.py (94 changes: 33 additions & 61 deletions)

```diff
@@ -36,8 +36,8 @@
 CreateValueHook = tp.Callable[['Variable[A]', A], A]
 AxisName = str
 AxisIndex = int
-AddAxisHook = tp.Callable[[V, AxisName, AxisIndex], None]
-RemoveAxisHook = tp.Callable[[V, AxisName, AxisIndex], None]
+AddAxisHook = tp.Callable[[V, AxisIndex, AxisName | None], None]
+RemoveAxisHook = tp.Callable[[V, AxisIndex, AxisName | None], None]
 
 VariableTypeCache: dict[str, tp.Type[Variable[tp.Any]]] = {}
 
@@ -150,67 +150,43 @@ def __init__(
     **metadata: tp.Any,
   ):
     vars(self)['_trace_state'] = tracers.TraceState()
-    if set_value_hooks:
-      if callable(set_value_hooks):
-        set_value_hooks = (set_value_hooks,)
-      else:
-        set_value_hooks = tuple(set_value_hooks)
-    else:
-      set_value_hooks = ()
+    if callable(set_value_hooks):
+      set_value_hooks = (set_value_hooks,)
+    else:
+      set_value_hooks = tuple(set_value_hooks)
 
-    if get_value_hooks:
-      if callable(get_value_hooks):
-        get_value_hooks = (get_value_hooks,)
-      else:
-        get_value_hooks = tuple(get_value_hooks)
-    else:
-      get_value_hooks = ()
+    if callable(get_value_hooks):
+      get_value_hooks = (get_value_hooks,)
+    else:
+      get_value_hooks = tuple(get_value_hooks)
 
-    if create_value_hooks:
-      if callable(create_value_hooks):
-        create_value_hooks = (create_value_hooks,)
-      else:
-        create_value_hooks = tuple(create_value_hooks)
-    else:
-      create_value_hooks = ()
+    if callable(create_value_hooks):
+      create_value_hooks = (create_value_hooks,)
+    else:
+      create_value_hooks = tuple(create_value_hooks)
 
-    if add_axis_hooks:
-      if callable(add_axis_hooks):
-        add_axis_hooks = (add_axis_hooks,)
-      else:
-        add_axis_hooks = tuple(add_axis_hooks)
-    else:
-      add_axis_hooks = ()
+    if callable(add_axis_hooks):
+      add_axis_hooks = (add_axis_hooks,)
+    else:
+      add_axis_hooks = tuple(add_axis_hooks)
 
-    if remove_axis_hooks:
-      if callable(remove_axis_hooks):
-        remove_axis_hooks = (remove_axis_hooks,)
-      else:
-        remove_axis_hooks = tuple(remove_axis_hooks)
-    else:
-      remove_axis_hooks = ()
+    if callable(remove_axis_hooks):
+      remove_axis_hooks = (remove_axis_hooks,)
+    else:
+      remove_axis_hooks = tuple(remove_axis_hooks)
 
     if isinstance(value, VariableMetadata):
       value_metadata = dict(value.metadata)
-      if set_value_hooks and value.set_value_hooks:
-        set_value_hooks = set_value_hooks + value.set_value_hooks
-      elif value.set_value_hooks:
-        set_value_hooks = value.set_value_hooks
+      if value.set_value_hooks:
+        set_value_hooks = set_value_hooks + value.set_value_hooks
-      if get_value_hooks and value.get_value_hooks:
-        get_value_hooks = get_value_hooks + value.get_value_hooks
-      elif value.get_value_hooks:
-        get_value_hooks = value.get_value_hooks
+      if value.get_value_hooks:
+        get_value_hooks = get_value_hooks + value.get_value_hooks
-      if create_value_hooks and value.create_value_hooks:
-        create_value_hooks = create_value_hooks + value.create_value_hooks
-      elif value.create_value_hooks:
-        create_value_hooks = value.create_value_hooks
+      if value.create_value_hooks:
+        create_value_hooks = create_value_hooks + value.create_value_hooks
-      if add_axis_hooks and value.add_axis_hooks:
-        add_axis_hooks = add_axis_hooks + value.add_axis_hooks
-      elif value.add_axis_hooks:
-        add_axis_hooks = value.add_axis_hooks
+      if value.add_axis_hooks:
+        add_axis_hooks = add_axis_hooks + value.add_axis_hooks
-      if remove_axis_hooks and value.remove_axis_hooks:
-        remove_axis_hooks = remove_axis_hooks + value.remove_axis_hooks
-      elif value.remove_axis_hooks:
-        remove_axis_hooks = value.remove_axis_hooks
+      if value.remove_axis_hooks:
+        remove_axis_hooks = remove_axis_hooks + value.remove_axis_hooks
 
       metadata.update(value_metadata)
       value = tp.cast(A, value.raw_value)
@@ -318,13 +294,13 @@ def create_value(self, value: A):
       value = hook(self, value)
     return value
 
-  def add_axis(self, axis_name: AxisName, axis_index: AxisIndex):
+  def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None):
     for hook in self.add_axis_hooks:
-      hook(self, axis_name, axis_index)
+      hook(self, axis_index, axis_name)
 
-  def remove_axis(self, axis_name: AxisName, axis_index: AxisIndex):
+  def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None):
     for hook in self.remove_axis_hooks:
-      hook(self, axis_name, axis_index)
+      hook(self, axis_index, axis_name)
 
   def __eq__(self, other: object) -> bool:
     return type(self) is type(other) and vars(other) == vars(self)
@@ -418,11 +394,11 @@ def on_set_value(self, value: A) -> A: ...
   def on_create_value(self, value: A) -> A: ...
 
   def on_add_axis(
-    self: V, axis_name: AxisName, axis_index: AxisIndex
+    self: V, axis_index: AxisIndex, axis_name: AxisName | None
   ) -> V: ...
 
   def on_remove_axis(
-    self: V, axis_name: AxisName, axis_index: AxisIndex
+    self: V, axis_index: AxisIndex, axis_name: AxisName | None
   ) -> V: ...
 
   def __jax_array__(self):
@@ -870,17 +846,13 @@ def get_metadata(self) -> dict[str, tp.Any]:
     del metadata['value']
     return metadata
 
-  def add_axis(self, axis_name: AxisName, axis_index: AxisIndex):
-    if not hasattr(self, 'add_axis_hooks'):
-      raise ValueError(f'No add_axis_hooks found for VariableState: {self}')
+  def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None):
     for hook in self.add_axis_hooks:
-      hook(self, axis_name, axis_index)
+      hook(self, axis_index, axis_name)
 
-  def remove_axis(self, axis_name: AxisName, axis_index: AxisIndex):
-    if not hasattr(self, 'remove_axis_hooks'):
-      raise ValueError(f'No remove_axis_hooks found for VariableState: {self}')
+  def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None):
     for hook in self.remove_axis_hooks:
-      hook(self, axis_name, axis_index)
+      hook(self, axis_index, axis_name)
 
 
 def _variable_state_flatten(x: VariableState[tp.Any], *, with_keys: bool):
```
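Two things to note in the `variables.py` diff. First, the `__init__` cleanup replaces five copies of a three-branch conditional with a uniform two-branch rule: a bare callable becomes a 1-tuple, and anything else is materialized with `tuple(...)`, which subsumes the old empty-default branch because `tuple(()) == ()`. A minimal standalone sketch of that rule:

```python
import typing as tp

def normalize_hooks(
  hooks: tp.Callable | tp.Iterable[tp.Callable],
) -> tuple[tp.Callable, ...]:
  # Mirrors the simplified normalization in Variable.__init__: a single
  # callable becomes a 1-tuple; any iterable (including the default empty
  # tuple) is materialized as a tuple.
  if callable(hooks):
    return (hooks,)
  return tuple(hooks)

assert normalize_hooks(print) == (print,)
assert normalize_hooks([print, repr]) == (print, repr)
assert normalize_hooks(()) == ()  # covers the old `else: hooks = ()` branch
```

Second, the merge logic under `isinstance(value, VariableMetadata)` no longer needs an `elif`: since the hooks are now always tuples, concatenating with `+` handles both the empty and non-empty cases.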
tests/nnx/spmd_test.py (58 changes: 58 additions & 0 deletions)

```diff
@@ -100,6 +100,64 @@ def __call__(self, x):
     assert state_spec.opt_state[0].mu['w'].value == PartitionSpec('row', 'col')
     assert state_spec.opt_state[0].nu['w'].value == PartitionSpec('row', 'col')
 
+  def test_add_remove_axis_in_transform(self):
+    test = self
+    kadds, kremoves, badds, bremoves = [], [], [], []
+    class MLP(nnx.Module):
+
+      @nnx.split_rngs(splits=5)
+      @nnx.vmap(
+        in_axes=(0, 0),
+        transform_metadata={nnx.PARTITION_NAME: 'layers'},
+      )
+      def __init__(self, rngs: nnx.Rngs):
+        self.linear = nnx.Linear(
+          3,
+          3,
+          kernel_init=nnx.with_metadata(
+            nnx.initializers.lecun_normal(),
+            sharding=('din', 'dout'),
+            add_axis_hooks=lambda _, idx, name: kadds.append((idx, name)),
+            remove_axis_hooks=lambda _, idx, name: kremoves.append((idx, name)),
+          ),
+          bias_init=nnx.with_metadata(
+            nnx.initializers.zeros_init(),  # no sharding annotation here!
+            add_axis_hooks=lambda _, idx, name: badds.append((idx, name)),
+            remove_axis_hooks=lambda _, idx, name: bremoves.append((idx, name)),
+          ),
+          rngs=rngs,
+        )
+
+      @nnx.scan(
+        in_axes=(0, nnx.Carry),
+        transform_metadata={nnx.PARTITION_NAME: 'layers'},
+      )
+      def __call__(self, x: jax.Array):
+        x = self.linear(x)
+        # check that the sharded layer axis is not present inside scan
+        test.assertEqual(self.linear.kernel.shape, (3, 3))
+        test.assertEqual(self.linear.kernel.sharding, ('din', 'dout'))
+        # at least one remove_axis call has already removed the layer axis
+        test.assertEqual(kremoves[-1], (0, 'layers'))
+        test.assertEqual(bremoves[-1], (0, 'layers'))
+        return x, None
+
+    m = MLP(rngs=nnx.Rngs(0))
+    self.assertEqual(m.linear.kernel.shape, (5, 3, 3))
+    self.assertEqual(m.linear.kernel.sharding, ('layers', 'din', 'dout'))
+    self.assertEqual(m.linear.bias.shape, (5, 3))
+    # one add_axis call adds the `nnx.vmap` dimension
+    self.assertEqual(kadds, [(0, 'layers')])
+    self.assertEqual(kremoves, [])
+    self.assertEqual(badds, [(0, 'layers')])
+    self.assertEqual(bremoves, [])
+
+    # one remove_axis and one add_axis call on the way into and out of `nnx.scan`
+    y = m(jnp.ones((5, 3)))
+    self.assertEqual(kadds, [(0, 'layers'), (0, 'layers')])
+    self.assertEqual(kremoves, [(0, 'layers')])
+    self.assertEqual(badds, [(0, 'layers'), (0, 'layers')])
+    self.assertEqual(bremoves, [(0, 'layers')])
+
 
 if __name__ == '__main__':
   absltest.main()
```
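The test drives the hooks through `nnx.vmap` and `nnx.scan`; they can also be invoked directly via the new `Variable.add_axis`/`remove_axis` signatures. A minimal sketch, assuming the `flax.nnx` constructor arguments shown in the diff above (the hook here only records its arguments; it does not itself reshape the value):

```python
import jax.numpy as jnp
from flax import nnx

calls = []
# Hooks now receive (variable, axis_index, axis_name): the index comes first.
p = nnx.Param(
  jnp.zeros((3, 3)),
  add_axis_hooks=lambda _, idx, name: calls.append(('add', idx, name)),
  remove_axis_hooks=lambda _, idx, name: calls.append(('remove', idx, name)),
)
p.add_axis(0, 'layers')    # named axis, e.g. one lifted by nnx.vmap
p.remove_axis(0, 'layers')
p.add_axis(0, None)        # the axis name is now optional
assert calls == [('add', 0, 'layers'), ('remove', 0, 'layers'), ('add', 0, None)]
```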