From 029130308a4f304a5b8c7ff4c0c006b55d556874 Mon Sep 17 00:00:00 2001 From: kaixih Date: Tue, 16 Jan 2024 19:00:00 +0000 Subject: [PATCH 1/3] Use custom grad accum for fp8 meta params --- flax/linen/fp8_ops.py | 111 ++++++++++++++++++++++++++++++-------- tests/linen/linen_test.py | 47 ++++++++++++++-- 2 files changed, 131 insertions(+), 27 deletions(-) diff --git a/flax/linen/fp8_ops.py b/flax/linen/fp8_ops.py index c84aacfe73..0f9712cd14 100644 --- a/flax/linen/fp8_ops.py +++ b/flax/linen/fp8_ops.py @@ -12,16 +12,75 @@ # See the License for the specific language governing permissions and # limitations under the License. +import dataclasses +import numpy as np import warnings from functools import partial from jax import custom_jvp, custom_vjp, lax, random from jax import numpy as jnp +from jax._src import core +from jax._src import dtypes from flax.linen import initializers, module OVERWRITE_WITH_GRADIENT = '_overwrite_with_gradient' +# Define a custom dtype for FP8 meta params. +class Fp8MetaTyRules: + # tell JAX how to lower this dtype to an HLO dtype + @staticmethod + def physical_element_aval(dtype) -> core.ShapedArray: + return core.ShapedArray((), dtype.float_dtype) + + # allow conversions to and from the corresponding float type + @staticmethod + def convert_from(fp8_meta_dtype, other_dtype) -> bool: + return fp8_meta_dtype.float_dtype == other_dtype + + @staticmethod + def convert_to(other_dtype, fp8_meta_dtype) -> bool: + return fp8_meta_dtype.float_dtype == other_dtype + + # define how autodiff should accumulate these values + @staticmethod + def add(dt, x, y): + from_fp8_meta = partial(lax.convert_element_type, new_dtype=dt.float_dtype) + to_fp8_meta = partial(lax.convert_element_type, new_dtype=dt) + return to_fp8_meta(lax.max(from_fp8_meta(x), from_fp8_meta(y))) + + @staticmethod + def zero(dt): + neginf = np.array(-np.inf if dtypes.supports_inf(dt.float_dtype) + else dtypes.finfo(dt.float_dtype).min, dt.float_dtype) + return lax.convert_element_type(neginf, dt) + + @staticmethod + def tangent_dtype(dtype): + return dtype + + # NOTE: by skipping some rules, this dtype can only be used underneath jit + @staticmethod + def global_sharded_result_handler(aval, sharding, committed, is_from_xla): + raise NotImplementedError("convert back under the jit") + + +# class to use as second argument to jax.dtypes.issubdtype +class fp8_meta_dtype(dtypes.extended): pass + +# parameterized datatype for use in e.g. 
lax.convert_element_type +@dataclasses.dataclass(frozen=True) +class fp8_meta_dtype_wrapper(dtypes.ExtendedDType): + float_dtype: dtypes.DType + _rules: type = Fp8MetaTyRules + type: type = fp8_meta_dtype + + def __repr__(self) -> str: + nbits = dtypes.finfo(self.float_dtype).bits + return f'fp8_meta{nbits}' + name = property(__repr__) + +fm32 = fp8_meta_dtype_wrapper(jnp.float32) def get_fp8_max(fp8_dtype, out_dtype): assert fp8_dtype in (jnp.float8_e4m3fn, jnp.float8_e5m2) @@ -60,21 +119,29 @@ def compute_scale(amax, scale, fp8_max, margin=0): return 1.0 / sf -def compute_scale_and_amax_history(x, q_dtype, scale, amax_history): - dtype_max = get_fp8_max(q_dtype, jnp.float32) - amax_update = jnp.max(jnp.abs(x)).astype(scale.dtype) +def compute_amax_history(x, amax_history): + amax_update = jnp.max(jnp.abs(x)).astype(amax_history.dtype) new_history = jnp.roll(amax_history, shift=-1, axis=0).at[0].set(amax_update) - amax_from_history = jnp.max(new_history, axis=0) + return new_history + + +def qdq_and_return(x, q_dtype, sf_fm32, ah_fm32, compute_dtype): + # convert fm32->f32 so we can do math + amax_history = lax.convert_element_type(ah_fm32, jnp.float32) + scale = lax.convert_element_type(sf_fm32, jnp.float32) + + dtype_max = get_fp8_max(q_dtype, jnp.float32) + amax_from_history = jnp.max(amax_history, axis=0) new_scale = compute_scale(amax_from_history, scale, dtype_max) - return new_scale, new_history + qx = quantize_dequantize(x, q_dtype, new_scale, compute_dtype) -def qdq_and_return(x, q_dtype, scale, amax_history, compute_dtype): - qx = quantize_dequantize(x, q_dtype, scale, compute_dtype) - new_scale, new_history = compute_scale_and_amax_history( - x, q_dtype, scale, amax_history - ) - return qx, new_scale, new_history + new_history = compute_amax_history(x, amax_history) + + # convert f32->fm32 so the autodiff system accumulates fp8 meta correctly + new_ah_fm32 = lax.convert_element_type(new_history, fm32) + new_sf_fm32 = lax.convert_element_type(new_scale, fm32) + return qx, new_sf_fm32, new_ah_fm32 @partial(custom_vjp, nondiff_argnums=(0,)) @@ -202,18 +269,18 @@ def __call__(self, *args, **kwargs): comp_dtype = k.dtype x = jnp.asarray(x, comp_dtype) - x_qdq = in_qdq( - comp_dtype, x, self.input_scale.value, self.input_amax_history.value - ) - k_qdq = in_qdq( - comp_dtype, k, self.kernel_scale.value, self.kernel_amax_history.value + x_sf_fm32 = lax.convert_element_type(self.input_scale.value, fm32) + x_ah_fm32 = lax.convert_element_type(self.input_amax_history.value, fm32) + k_sf_fm32 = lax.convert_element_type(self.kernel_scale.value, fm32) + k_ah_fm32 = lax.convert_element_type(self.kernel_amax_history.value, fm32) + g_sf_fm32 = lax.convert_element_type(self.output_grad_scale.value, fm32) + g_ah_fm32 = lax.convert_element_type( + self.output_grad_amax_history.value, fm32 ) + + x_qdq = in_qdq(comp_dtype, x, x_sf_fm32, x_ah_fm32) + k_qdq = in_qdq(comp_dtype, k, k_sf_fm32, k_ah_fm32) y_qdq = dot_general_with_precision(x_qdq, k_qdq, dimension_numbers) # type: ignore - y = out_qdq( - comp_dtype, - y_qdq, - self.output_grad_scale.value, - self.output_grad_amax_history.value, - ) + y = out_qdq(comp_dtype, y_qdq, g_sf_fm32, g_ah_fm32) return y # type: ignore diff --git a/tests/linen/linen_test.py b/tests/linen/linen_test.py index b72e39dd2d..f622a73027 100644 --- a/tests/linen/linen_test.py +++ b/tests/linen/linen_test.py @@ -1171,7 +1171,9 @@ def run(fp8_injection, expected_shapes): p = nn.DenseGeneral(features=64, name='dense') if fp8_injection: p.dot_general_cls = nn.Fp8DotGeneralOp - 
y, initial_vars = p.init_with_output(init_key, x) + + init_fn = jax.jit(p.init_with_output) + y, initial_vars = init_fn(init_key, x) var_shapes = jax.tree_util.tree_map(jnp.shape, initial_vars) self.assertEqual(var_shapes, expected_shapes) @@ -1216,7 +1218,9 @@ def test_fp8_train_state(self): dense = nn.DenseGeneral( features=32, use_bias=True, dot_general_cls=nn.Fp8DotGeneralOp ) - variables = dense.init(init_key, x) + + init_fn = jax.jit(dense.init) + variables = init_fn(init_key, x) opt = optax.adam(learning_rate=0.1) state = train_state.TrainState.create( params=variables, tx=opt, apply_fn=dense.apply @@ -1251,15 +1255,15 @@ def loss_fn(vars): k = state.params['params']['kernel'] # Manually compute the expected amax history and scaling factors. - amax_history_x = _roll_and_update(amax_history_x, jnp.max(jnp.abs(x))) - amax_history_k = _roll_and_update(amax_history_k, jnp.max(jnp.abs(k))) - amax_history_g = _roll_and_update(amax_history_g, jnp.max(jnp.abs(g))) amax_from_history_x = jnp.max(amax_history_x, axis=0) amax_from_history_k = jnp.max(amax_history_k, axis=0) amax_from_history_g = jnp.max(amax_history_g, axis=0) scale_x = fp8_ops.compute_scale(amax_from_history_x, scale_x, e4m3_max) scale_k = fp8_ops.compute_scale(amax_from_history_k, scale_k, e4m3_max) scale_g = fp8_ops.compute_scale(amax_from_history_g, scale_g, e5m2_max) + amax_history_x = _roll_and_update(amax_history_x, jnp.max(jnp.abs(x))) + amax_history_k = _roll_and_update(amax_history_k, jnp.max(jnp.abs(k))) + amax_history_g = _roll_and_update(amax_history_g, jnp.max(jnp.abs(g))) state = train_fn(state, x, g) @@ -1290,6 +1294,39 @@ def loss_fn(vars): np.testing.assert_allclose(fp8_vars['kernel_scale'][0], scale_k) np.testing.assert_allclose(fp8_vars['output_grad_scale'][0], scale_g) + def test_fp8_meta_dtype(self): + f32 = jnp.dtype('float32') + fm32 = fp8_ops.fm32 + + # Create a scan loop with reused ah_f32 and sf_f32. So, the autograd will + # accumulate the grads of them. We expect the max op (rather than add op) + # for the accumulation by converting them to fm32 dtype. + def outer(x, ah_f32, sf_f32): + ah_fm32 = jax.lax.convert_element_type(ah_f32, fm32) + sf_fm32 = jax.lax.convert_element_type(sf_f32, fm32) + array_x = jnp.array([x], f32) + def body_fun(carry, _): + carry = fp8_ops.in_qdq(f32, carry, sf_fm32, ah_fm32) + return carry, None + array_x, _ = jax.lax.scan(body_fun, array_x, None, length=3) + return array_x[0] + + outer_fn = jax.jit(jax.grad(outer, (0, 1, 2))) + ah = jnp.array([0., 0., 0.], f32) + sf = jnp.array([1.], f32) + # 1st iteration + grads, new_ah, new_sf = outer_fn(2.0, ah, sf) + np.testing.assert_allclose(new_ah, [2., 0., 0.]) + np.testing.assert_allclose(new_sf, [1.]) + # 2nd iteration + grads, new_ah, new_sf = outer_fn(3., new_ah, new_sf) + np.testing.assert_allclose(new_ah, [3., 0., 2.]) + np.testing.assert_allclose(new_sf, [2. / 448]) + # 3rd iteration + grads, new_ah, new_sf = outer_fn(4., new_ah, new_sf) + np.testing.assert_allclose(new_ah, [4., 2., 3.]) + np.testing.assert_allclose(new_sf, [3. 
/ 448]) + if __name__ == '__main__': absltest.main() From 54fa392ba612efd3508d85fc4b22f1b72f5d250f Mon Sep 17 00:00:00 2001 From: kaixih Date: Wed, 31 Jan 2024 19:21:03 +0000 Subject: [PATCH 2/3] Add the full attr --- flax/linen/fp8_ops.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/flax/linen/fp8_ops.py b/flax/linen/fp8_ops.py index 0f9712cd14..c9a70181e6 100644 --- a/flax/linen/fp8_ops.py +++ b/flax/linen/fp8_ops.py @@ -59,6 +59,12 @@ def zero(dt): def tangent_dtype(dtype): return dtype + @staticmethod + def full(shape, fill_value, dtype): + fill_value = lax.convert_element_type(fill_value, dtype.float_dtype) + out_raw = lax.full(shape, fill_value, dtype.float_dtype) + return lax.convert_element_type(out_raw, dtype) + # NOTE: by skipping some rules, this dtype can only be used underneath jit @staticmethod def global_sharded_result_handler(aval, sharding, committed, is_from_xla): From dd004c2127e6014928ddd176f516c54be48e94a4 Mon Sep 17 00:00:00 2001 From: kaixih Date: Thu, 1 Feb 2024 07:31:02 +0000 Subject: [PATCH 3/3] Clean up --- flax/linen/fp8_ops.py | 37 ++++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/flax/linen/fp8_ops.py b/flax/linen/fp8_ops.py index c9a70181e6..d3543c3bd3 100644 --- a/flax/linen/fp8_ops.py +++ b/flax/linen/fp8_ops.py @@ -131,10 +131,12 @@ def compute_amax_history(x, amax_history): return new_history -def qdq_and_return(x, q_dtype, sf_fm32, ah_fm32, compute_dtype): +def qdq_and_return(x, q_dtype, scale, amax_history, compute_dtype): + is_fm32 = scale.dtype == fm32 and amax_history.dtype == fm32 # convert fm32->f32 so we can do math - amax_history = lax.convert_element_type(ah_fm32, jnp.float32) - scale = lax.convert_element_type(sf_fm32, jnp.float32) + if is_fm32: + amax_history = lax.convert_element_type(amax_history, jnp.float32) + scale = lax.convert_element_type(scale, jnp.float32) dtype_max = get_fp8_max(q_dtype, jnp.float32) amax_from_history = jnp.max(amax_history, axis=0) @@ -145,9 +147,10 @@ def qdq_and_return(x, q_dtype, sf_fm32, ah_fm32, compute_dtype): new_history = compute_amax_history(x, amax_history) # convert f32->fm32 so the autodiff system accumulates fp8 meta correctly - new_ah_fm32 = lax.convert_element_type(new_history, fm32) - new_sf_fm32 = lax.convert_element_type(new_scale, fm32) - return qx, new_sf_fm32, new_ah_fm32 + if is_fm32: + new_history = lax.convert_element_type(new_history, fm32) + new_scale = lax.convert_element_type(new_scale, fm32) + return qx, new_scale, new_history @partial(custom_vjp, nondiff_argnums=(0,)) @@ -275,18 +278,18 @@ def __call__(self, *args, **kwargs): comp_dtype = k.dtype x = jnp.asarray(x, comp_dtype) - x_sf_fm32 = lax.convert_element_type(self.input_scale.value, fm32) - x_ah_fm32 = lax.convert_element_type(self.input_amax_history.value, fm32) - k_sf_fm32 = lax.convert_element_type(self.kernel_scale.value, fm32) - k_ah_fm32 = lax.convert_element_type(self.kernel_amax_history.value, fm32) - g_sf_fm32 = lax.convert_element_type(self.output_grad_scale.value, fm32) - g_ah_fm32 = lax.convert_element_type( - self.output_grad_amax_history.value, fm32 + x_qdq = in_qdq( + comp_dtype, x, self.input_scale.value, self.input_amax_history.value + ) + k_qdq = in_qdq( + comp_dtype, k, self.kernel_scale.value, self.kernel_amax_history.value ) - - x_qdq = in_qdq(comp_dtype, x, x_sf_fm32, x_ah_fm32) - k_qdq = in_qdq(comp_dtype, k, k_sf_fm32, k_ah_fm32) y_qdq = dot_general_with_precision(x_qdq, k_qdq, dimension_numbers) # type: ignore - y = out_qdq(comp_dtype, y_qdq, 
g_sf_fm32, g_ah_fm32) + y = out_qdq( + comp_dtype, + y_qdq, + self.output_grad_scale.value, + self.output_grad_amax_history.value, + ) return y # type: ignore
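
Usage sketch (not part of the patches above): the behavioural change introduced by the fm32 extended dtype is easiest to see in isolation. The snippet below is a condensed variant of the new test_fp8_meta_dtype, assuming the patched flax.linen.fp8_ops module is importable as in linen_test.py. The same scale and amax-history values are reused across three scan iterations; because they are converted to fm32 first, reverse-mode autodiff combines the per-use cotangents with the dtype's add rule, i.e. a max rather than a sum. Everything stays under jax.jit, since the dtype deliberately skips the non-jit result-handler rules.

import jax
import jax.numpy as jnp
import numpy as np
from flax.linen import fp8_ops

f32 = jnp.dtype('float32')

def outer(x, ah_f32, sf_f32):
  # Convert the float32 metas to the fm32 extended dtype so that reverse-mode
  # autodiff accumulates their per-use cotangents with max instead of add.
  ah_fm32 = jax.lax.convert_element_type(ah_f32, fp8_ops.fm32)
  sf_fm32 = jax.lax.convert_element_type(sf_f32, fp8_ops.fm32)
  array_x = jnp.array([x], f32)

  def body_fun(carry, _):
    # The same fm32 metas are reused on every scan iteration.
    carry = fp8_ops.in_qdq(f32, carry, sf_fm32, ah_fm32)
    return carry, None

  array_x, _ = jax.lax.scan(body_fun, array_x, None, length=3)
  return array_x[0]

# Differentiate w.r.t. (x, amax_history, scale); keep the call under jit
# because the dtype only defines jit-time lowering rules.
outer_fn = jax.jit(jax.grad(outer, (0, 1, 2)))
_, new_ah, new_sf = outer_fn(2.0, jnp.zeros(3, f32), jnp.ones(1, f32))
np.testing.assert_allclose(new_ah, [2., 0., 0.])  # max over the three uses, not their sum
np.testing.assert_allclose(new_sf, [1.])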
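
A second sketch, loosely following test_fp8_train_state above, shows how Fp8DotGeneralOp is typically driven from a training loop. The input shape, loss, learning rate, and the train_step name are illustrative only; the snippet assumes optax and flax.training.train_state (both already used in linen_test.py) and assumes TrainState.apply_gradients special-cases the OVERWRITE_WITH_GRADIENT collection, which is what the fp8 meta params rely on and what test_fp8_train_state exercises.

import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
from flax.training import train_state

key = jax.random.PRNGKey(0)
x = jnp.ones((4, 16), jnp.float32)

dense = nn.DenseGeneral(features=32, use_bias=True,
                        dot_general_cls=nn.Fp8DotGeneralOp)
# Jit the init, matching the updated tests: the fm32 extended dtype defined in
# fp8_ops only provides jit-time rules.
variables = jax.jit(dense.init)(key, x)

state = train_state.TrainState.create(
    params=variables, tx=optax.adam(learning_rate=0.1), apply_fn=dense.apply
)

@jax.jit
def train_step(state, x):
  def loss_fn(vars):
    y = state.apply_fn(vars, x)
    return jnp.sum(y ** 2)
  grads = jax.grad(loss_fn)(state.params)
  # Ordinary params follow the optimizer; the fp8 scales and amax histories are
  # flagged OVERWRITE_WITH_GRADIENT, so the new meta values produced by
  # in_qdq/out_qdq overwrite the stored ones instead of being applied as grads.
  return state.apply_gradients(grads=grads)

state = train_step(state, x)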