Skip to content

Commit

Permalink
Add back bias metadata wrapper code that was accidentally removed.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 681103256
  • Loading branch information
phoenix-meadowlark authored and copybara-github committed Oct 1, 2024
1 parent 7f2f6a4 commit 9acd003
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions aqt/jax/v2/flax/aqt_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ def _freezer_qtensor_init_wrapper(
scale_non_shard_axis_all = list(range(qt.ndim))
scale_non_shard_axis_contracting = list(contracting_axis)

def _get_singleton_axes(x: jnp.ndarray) -> list[utils.AxisIdx]:
return [axis for axis, dim in enumerate(x.shape) if dim == 1]

qt = qt.replace(
qvalue=axis_metadata_wrapper(qt.qvalue, tile_map, []),
scale=jax.tree.map(
Expand All @@ -88,6 +91,13 @@ def _freezer_qtensor_init_wrapper(
),
qt.scale_t,
),
# Set the non-sharding axes for bias to the singleton dimensions.
bias=jax.tree.map(
lambda x: axis_metadata_wrapper(
x, tile_map, _get_singleton_axes(x)
),
qt.bias,
),
)
return qt

Expand Down

0 comments on commit 9acd003

Please sign in to comment.