Skip to content

Commit

Permalink
modify layer_stack transparency map
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 551653859
  • Loading branch information
Haiku Contributor authored and copybara-github committed Jul 27, 2023
1 parent 6526bc8 commit 308434e
Showing 1 changed file with 10 additions and 28 deletions.
38 changes: 10 additions & 28 deletions haiku/_src/layer_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,22 +89,16 @@ def _split_params(
name_map: LayerStackTransparencyMapping,
) -> base.Params:
"""Splits the stacked parameters."""

def _split(x):
return [jnp.squeeze(s, axis=0) for s in jnp.split(x, x.shape[0], axis=0)]

params = {}
for mod_name, mod_params in stacked_params.items():
split_mod_params = {k: _split(v) for k, v in mod_params.items()}
for i in range(num_layers):
new_mod_name = name_map.stacked_to_flat(mod_name, i)
if new_mod_name in params:
raise ValueError(
f"Found conflicting unstacked module name for {mod_name} at"
f" {new_mod_name}."
)
params[new_mod_name] = {k: v[i] for k, v in split_mod_params.items()}

params[new_mod_name] = jax.tree_map(lambda x: x[i], mod_params) # pylint:disable=cell-var-from-loop
return params


Expand All @@ -114,32 +108,20 @@ def _stack_params(
name_map: LayerStackTransparencyMapping,
) -> base.Params:
"""Stacks the split parameters."""
params = {}
make_empty_param_stack = lambda: ([None] * num_layers)

# Construct a separate tree for each loop iteration, which we will then
# multimap over in a call to jnp.stack. This formulation preserves custom
# pytree node types.
param_trees = [{} for _ in range(num_layers)]
for mod_name, mod_params in split_params.items():
stacked_name_idx = name_map.flat_to_stacked(mod_name)
# If the transparency map returns None, this param is not part of the stack.
if stacked_name_idx is None:
continue
stacked_mod_name, idx = stacked_name_idx
if stacked_mod_name not in params:
params[stacked_mod_name] = collections.defaultdict(make_empty_param_stack)

for k, v in mod_params.items():
if params[stacked_mod_name][k][idx] is not None:
raise ValueError(
f"Found conflicting values for param {stacked_mod_name}/{k} at"
f" index {idx}."
)
params[stacked_mod_name][k][idx] = v

for mod_name, mod_params in params.items():
for k, v in mod_params.items():
if None in v:
raise ValueError(f"Couldn't find all params for {mod_name}/{k}: {v}")
mod_params[k] = jnp.stack(v, axis=0)

return params
if stacked_mod_name not in param_trees[idx]:
param_trees[idx][stacked_mod_name] = {}
param_trees[idx][stacked_mod_name].update(mod_params)
return jax.tree_map(lambda *args: jnp.stack(args, axis=0), *param_trees)


class _LayerStack:
Expand Down

0 comments on commit 308434e

Please sign in to comment.