Skip to content

Commit

Permalink
Handle config aliasing in make_dot_general.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 560847701
  • Loading branch information
Cerebra Catalyst Team authored and copybara-github committed Aug 28, 2023
1 parent cf54757 commit 59bd082
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion aqt/jax/v2/aqt_dot_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def _maybe_inv(x):
return 1.0 / x


def _make_dot_general_raw(cfg: config.DotGeneralRaw):
def _make_dot_general_raw(gcfg: config.DotGeneralRaw):
"""Makes quantized lax.dot_general replacement."""

def my_dot_general(
Expand All @@ -307,6 +307,9 @@ def my_dot_general(
context,
):
"""Creates a fake_quant function."""
# We need to copy because we modify cfg to populate some defaults.
cfg = copy.deepcopy(gcfg)

# TODO(lew):
# - Use qx.value with the int type.
# - Handle qx.value with the int type in an optimized way.
Expand Down

0 comments on commit 59bd082

Please sign in to comment.