Skip to content

Commit

Permalink
clean up axis hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
IvyZX committed Sep 11, 2024
1 parent f948154 commit d0fdfbf
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 63 deletions.
4 changes: 2 additions & 2 deletions flax/nnx/spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _add_axis(x: tp.Any):
sharding.insert(index, axis_name)
x.sharding = tuple(sharding) # type: ignore

x.add_axis(axis_name, index)
x.add_axis(index, axis_name)
return x

return jax.tree.map(
Expand All @@ -61,7 +61,7 @@ def _remove_axis(x: tp.Any):
sharding = list(x.sharding)
assert sharding.pop(index) == axis_name
x.sharding = tuple(sharding)
x.remove_axis(axis_name, index)
x.remove_axis(index, axis_name)
return x

return jax.tree.map(
Expand Down
94 changes: 33 additions & 61 deletions flax/nnx/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
# Type aliases for the Variable metadata hook callables.
CreateValueHook = tp.Callable[['Variable[A]', A], A]
AxisName = str
AxisIndex = int
# Axis hooks are called as hook(variable, axis_index, axis_name); the diff
# rendering left the pre-change (name, index) aliases interleaved here, so
# only the post-change (index, name) order is kept.
AddAxisHook = tp.Callable[[V, AxisIndex, AxisName], None]
RemoveAxisHook = tp.Callable[[V, AxisIndex, AxisName], None]

# Cache of dynamically created Variable subclasses, keyed by type name.
VariableTypeCache: dict[str, tp.Type[Variable[tp.Any]]] = {}

Expand Down Expand Up @@ -150,67 +150,43 @@ def __init__(
**metadata: tp.Any,
):
vars(self)['_trace_state'] = tracers.TraceState()
if set_value_hooks:
if callable(set_value_hooks):
set_value_hooks = (set_value_hooks,)
else:
set_value_hooks = tuple(set_value_hooks)
if callable(set_value_hooks):
set_value_hooks = (set_value_hooks,)
else:
set_value_hooks = ()
if get_value_hooks:
if callable(get_value_hooks):
get_value_hooks = (get_value_hooks,)
else:
get_value_hooks = tuple(get_value_hooks)
set_value_hooks = tuple(set_value_hooks)

if callable(get_value_hooks):
get_value_hooks = (get_value_hooks,)
else:
get_value_hooks = ()
get_value_hooks = tuple(get_value_hooks)

if create_value_hooks:
if callable(create_value_hooks):
create_value_hooks = (create_value_hooks,)
else:
create_value_hooks = tuple(create_value_hooks)
if callable(create_value_hooks):
create_value_hooks = (create_value_hooks,)
else:
create_value_hooks = ()
create_value_hooks = tuple(create_value_hooks)

if add_axis_hooks:
if callable(add_axis_hooks):
add_axis_hooks = (add_axis_hooks,)
else:
add_axis_hooks = tuple(add_axis_hooks)
if callable(add_axis_hooks):
add_axis_hooks = (add_axis_hooks,)
else:
add_axis_hooks = ()
add_axis_hooks = tuple(add_axis_hooks)

if remove_axis_hooks:
if callable(remove_axis_hooks):
remove_axis_hooks = (remove_axis_hooks,)
else:
remove_axis_hooks = tuple(remove_axis_hooks)
if callable(remove_axis_hooks):
remove_axis_hooks = (remove_axis_hooks,)
else:
remove_axis_hooks = ()
remove_axis_hooks = tuple(remove_axis_hooks)

if isinstance(value, VariableMetadata):
value_metadata = dict(value.metadata)
if set_value_hooks and value.set_value_hooks:
if value.set_value_hooks:
set_value_hooks = set_value_hooks + value.set_value_hooks
elif value.set_value_hooks:
set_value_hooks = value.set_value_hooks
if get_value_hooks and value.get_value_hooks:
if value.get_value_hooks:
get_value_hooks = get_value_hooks + value.get_value_hooks
elif value.get_value_hooks:
get_value_hooks = value.get_value_hooks
if create_value_hooks and value.create_value_hooks:
if value.create_value_hooks:
create_value_hooks = create_value_hooks + value.create_value_hooks
elif value.create_value_hooks:
create_value_hooks = value.create_value_hooks
if add_axis_hooks and value.add_axis_hooks:
if value.add_axis_hooks:
add_axis_hooks = add_axis_hooks + value.add_axis_hooks
elif value.add_axis_hooks:
add_axis_hooks = value.add_axis_hooks
if remove_axis_hooks and value.remove_axis_hooks:
if value.remove_axis_hooks:
remove_axis_hooks = remove_axis_hooks + value.remove_axis_hooks
elif value.remove_axis_hooks:
remove_axis_hooks = value.remove_axis_hooks

metadata.update(value_metadata)
value = tp.cast(A, value.raw_value)
Expand Down Expand Up @@ -318,13 +294,13 @@ def create_value(self, value: A):
value = hook(self, value)
return value

def add_axis(self, axis_index: 'AxisIndex', axis_name: 'AxisName | None'):
  """Notify all add-axis hooks that an axis was inserted.

  Args:
    axis_index: position of the newly inserted axis.
    axis_name: partition name of the axis, or None when unnamed.
  """
  # The diff rendering interleaved the old (name, index) signature with the
  # new one; only the post-change (index, name) argument order is kept.
  for hook in self.add_axis_hooks:
    hook(self, axis_index, axis_name)

def remove_axis(self, axis_index: 'AxisIndex', axis_name: 'AxisName | None'):
  """Notify all remove-axis hooks that an axis was removed (inverse of add_axis)."""
  for hook in self.remove_axis_hooks:
    hook(self, axis_index, axis_name)

def __eq__(self, other: object) -> bool:
  # Equal only to an object of the exact same type whose full attribute
  # dict (value plus all metadata) matches; subclasses never compare equal.
  return type(self) is type(other) and vars(other) == vars(self)
Expand Down Expand Up @@ -418,11 +394,11 @@ def on_set_value(self, value: A) -> A: ...
def on_create_value(self, value: A) -> A: ...

def on_add_axis(
self: V, axis_name: AxisName, axis_index: AxisIndex
self: V, axis_index: AxisIndex, axis_name: AxisName | None
) -> V: ...

def on_remove_axis(
self: V, axis_name: AxisName, axis_index: AxisIndex
self: V, axis_index: AxisIndex, axis_name: AxisName | None
) -> V: ...

def __jax_array__(self):
Expand Down Expand Up @@ -870,17 +846,13 @@ def get_metadata(self) -> dict[str, tp.Any]:
del metadata['value']
return metadata

def add_axis(self, axis_index: 'AxisIndex', axis_name: 'AxisName | None'):
  """Run each add-axis hook with (state, axis_index, axis_name)."""
  # NOTE(review): the old version's hasattr guard (raising ValueError when no
  # add_axis_hooks existed) was removed by this change; the new code assumes
  # the hooks attribute is always present — confirm against callers.
  for hook in self.add_axis_hooks:
    hook(self, axis_index, axis_name)

def remove_axis(self, axis_index: 'AxisIndex', axis_name: 'AxisName | None'):
  """Run each remove-axis hook with (state, axis_index, axis_name)."""
  for hook in self.remove_axis_hooks:
    hook(self, axis_index, axis_name)


def _variable_state_flatten(x: VariableState[tp.Any], *, with_keys: bool):
Expand Down
59 changes: 59 additions & 0 deletions tests/nnx/spmd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,65 @@ def __call__(self, x):
assert state_spec.opt_state[0].mu['w'].value == PartitionSpec('row', 'col')
assert state_spec.opt_state[0].nu['w'].value == PartitionSpec('row', 'col')

def test_add_remove_axis_in_transform(self):
  """Axis hooks fire with (axis_index, axis_name) around nnx transforms.

  The assertions below show: nnx.vmap adds a 'layers' axis when lifting
  __init__, and nnx.scan removes it on entry to __call__ and re-adds it on
  exit, each time invoking the registered add_axis/remove_axis hooks with
  (0, 'layers').
  """
  test = self  # alias so the nested module's methods can reach the TestCase
  # Recorded (axis_index, axis_name) tuples for kernel/bias hook calls.
  kadds, kremoves, badds, bremoves = [], [], [], []

  class MLP(nnx.Module):

    @nnx.split_rngs(splits=5)
    @nnx.vmap(
        in_axes=(0, 0),
        transform_metadata={nnx.PARTITION_NAME: 'layers'},
    )
    def __init__(self, rngs: nnx.Rngs):
      self.linear = nnx.Linear(
          3,
          3,
          kernel_init=nnx.with_metadata(
              nnx.initializers.lecun_normal(), sharding=('din', 'dout'),
              add_axis_hooks=lambda _, idx, name: kadds.append((idx, name)),
              remove_axis_hooks=lambda _, idx, name: kremoves.append((idx, name)),
          ),
          bias_init=nnx.with_metadata(
              nnx.initializers.zeros_init(),  # no sharding annotation here!
              add_axis_hooks=lambda _, idx, name: badds.append((idx, name)),
              remove_axis_hooks=lambda _, idx, name: bremoves.append((idx, name)),
          ),
          rngs=rngs,
      )

    @nnx.scan(
        in_axes=(0, nnx.Carry),
        transform_metadata={nnx.PARTITION_NAME: 'layers'}
    )
    def __call__(self, x: jax.Array):
      x = self.linear(x)
      # test sharding layer axes is not present inside scan
      test.assertEqual(self.linear.kernel.shape, (3, 3))
      test.assertEqual(self.linear.kernel.sharding, ('din', 'dout'))
      # at least a remove_axis was already called to remove the layer axis
      test.assertEqual(kremoves[-1], (0, 'layers'))
      test.assertEqual(bremoves[-1], (0, 'layers'))
      return x, None

  m = MLP(rngs=nnx.Rngs(0))
  # Outside the transform the 'layers' axis (size 5) is present again.
  self.assertEqual(m.linear.kernel.shape, (5, 3, 3))
  self.assertEqual(m.linear.kernel.sharding, ('layers', 'din', 'dout'))
  self.assertEqual(m.linear.bias.shape, (5, 3))
  self.assertEqual(m.linear.bias.sharding, ('layers', 'dout'))
  # One add_axis called to add the `nnx.vmap` dimension
  self.assertEqual(kadds, [(0, 'layers')])
  self.assertEqual(kremoves, [])
  self.assertEqual(badds, [(0, 'layers')])
  self.assertEqual(bremoves, [])

  # One remove_axis and one add_axis called when in and out of `nnx.scan`
  y = m(jnp.ones((5, 3)))
  self.assertEqual(kadds, [(0, 'layers'), (0, 'layers')])
  self.assertEqual(kremoves, [(0, 'layers')])
  self.assertEqual(badds, [(0, 'layers'), (0, 'layers')])
  self.assertEqual(bremoves, [(0, 'layers')])


if __name__ == '__main__':
  # Allow running this test module directly via absl's test runner.
  absltest.main()
Expand Down

0 comments on commit d0fdfbf

Please sign in to comment.