From 59bd082ffa92603db1b44cb0f765aed5bce16699 Mon Sep 17 00:00:00 2001 From: Cerebra Catalyst Team Date: Mon, 28 Aug 2023 16:17:51 -0700 Subject: [PATCH] Handle config aliasing in make_dot_general. PiperOrigin-RevId: 560847701 --- aqt/jax/v2/aqt_dot_general.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/aqt/jax/v2/aqt_dot_general.py b/aqt/jax/v2/aqt_dot_general.py index b441dbb8..f3321407 100644 --- a/aqt/jax/v2/aqt_dot_general.py +++ b/aqt/jax/v2/aqt_dot_general.py @@ -297,7 +297,7 @@ def _maybe_inv(x): return 1.0 / x -def _make_dot_general_raw(cfg: config.DotGeneralRaw): +def _make_dot_general_raw(gcfg: config.DotGeneralRaw): """Makes quantized lax.dot_general replacement.""" def my_dot_general( @@ -307,6 +307,9 @@ def my_dot_general( context, ): """Creates a fake_quant function.""" + # We need to copy because we modify cfg to populate some defaults. + cfg = copy.deepcopy(gcfg) + # TODO(lew): # - Use qx.value with the int type. # - Handle qx.value with the int type in an optimized way.