[NVIDIA] Use custom grad accumulation for FP8 params #3623

Merged
96 changes: 86 additions & 10 deletions flax/linen/fp8_ops.py
@@ -12,16 +12,81 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import dataclasses
+import numpy as np
 import warnings
 from functools import partial

 from jax import custom_jvp, custom_vjp, lax, random
 from jax import numpy as jnp
+from jax._src import core
+from jax._src import dtypes

 from flax.linen import initializers, module

 OVERWRITE_WITH_GRADIENT = '_overwrite_with_gradient'

+# Define a custom dtype for FP8 meta params.
+class Fp8MetaTyRules:
+  # tell JAX how to lower this dtype to an HLO dtype
+  @staticmethod
+  def physical_element_aval(dtype) -> core.ShapedArray:
+    return core.ShapedArray((), dtype.float_dtype)
+
+  # allow conversions to and from the corresponding float type
+  @staticmethod
+  def convert_from(fp8_meta_dtype, other_dtype) -> bool:
+    return fp8_meta_dtype.float_dtype == other_dtype
+
+  @staticmethod
+  def convert_to(other_dtype, fp8_meta_dtype) -> bool:
+    return fp8_meta_dtype.float_dtype == other_dtype
+
+  # define how autodiff should accumulate these values
+  @staticmethod
+  def add(dt, x, y):
+    from_fp8_meta = partial(lax.convert_element_type, new_dtype=dt.float_dtype)
+    to_fp8_meta = partial(lax.convert_element_type, new_dtype=dt)
+    return to_fp8_meta(lax.max(from_fp8_meta(x), from_fp8_meta(y)))
+
+  @staticmethod
+  def zero(dt):
+    neginf = np.array(-np.inf if dtypes.supports_inf(dt.float_dtype)
+                      else dtypes.finfo(dt.float_dtype).min, dt.float_dtype)
+    return lax.convert_element_type(neginf, dt)
+
+  @staticmethod
+  def tangent_dtype(dtype):
+    return dtype
+
+  @staticmethod
+  def full(shape, fill_value, dtype):
+    fill_value = lax.convert_element_type(fill_value, dtype.float_dtype)
+    out_raw = lax.full(shape, fill_value, dtype.float_dtype)
+    return lax.convert_element_type(out_raw, dtype)
+
+  # NOTE: by skipping some rules, this dtype can only be used underneath jit
+  @staticmethod
+  def global_sharded_result_handler(aval, sharding, committed, is_from_xla):
+    raise NotImplementedError("convert back under the jit")
+
+
+# class to use as second argument to jax.dtypes.issubdtype
+class fp8_meta_dtype(dtypes.extended): pass
+
+# parameterized datatype for use in e.g. lax.convert_element_type
+@dataclasses.dataclass(frozen=True)
+class fp8_meta_dtype_wrapper(dtypes.ExtendedDType):
+  float_dtype: dtypes.DType
+  _rules: type = Fp8MetaTyRules
+  type: type = fp8_meta_dtype
+
+  def __repr__(self) -> str:
+    nbits = dtypes.finfo(self.float_dtype).bits
+    return f'fp8_meta{nbits}'
+  name = property(__repr__)
+
+fm32 = fp8_meta_dtype_wrapper(jnp.float32)

 def get_fp8_max(fp8_dtype, out_dtype):
   assert fp8_dtype in (jnp.float8_e4m3fn, jnp.float8_e5m2)
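The `add` and `zero` rules above are the core of the change: when autodiff accumulates two cotangents carried in this dtype, the result is their elementwise maximum (with -inf as the identity) rather than their sum, which is the intended behavior for amax/scale meta params. A minimal sketch of that semantics, not part of the diff, assuming `flax.linen.fp8_ops` exposes `fm32` and `Fp8MetaTyRules` as defined here; `accumulate` is a hypothetical helper, and everything stays under `jit` per the NOTE above:

import jax
import jax.numpy as jnp
from flax.linen import fp8_ops

@jax.jit
def accumulate(x, y):
  # Round-trip f32 -> fm32 -> f32; the fm32 `add` rule reduces with max, not sum.
  xm = jax.lax.convert_element_type(x, fp8_ops.fm32)
  ym = jax.lax.convert_element_type(y, fp8_ops.fm32)
  acc = fp8_ops.Fp8MetaTyRules.add(fp8_ops.fm32, xm, ym)
  return jax.lax.convert_element_type(acc, jnp.float32)

print(accumulate(jnp.array([1., 3.]), jnp.array([2., 0.5])))  # expected: [2. 3.]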
@@ -60,20 +125,31 @@ def compute_scale(amax, scale, fp8_max, margin=0):
   return 1.0 / sf


-def compute_scale_and_amax_history(x, q_dtype, scale, amax_history):
-  dtype_max = get_fp8_max(q_dtype, jnp.float32)
-  amax_update = jnp.max(jnp.abs(x)).astype(scale.dtype)
+def compute_amax_history(x, amax_history):
+  amax_update = jnp.max(jnp.abs(x)).astype(amax_history.dtype)
   new_history = jnp.roll(amax_history, shift=-1, axis=0).at[0].set(amax_update)
-  amax_from_history = jnp.max(new_history, axis=0)
-  new_scale = compute_scale(amax_from_history, scale, dtype_max)
-  return new_scale, new_history
+  return new_history


 def qdq_and_return(x, q_dtype, scale, amax_history, compute_dtype):
-  qx = quantize_dequantize(x, q_dtype, scale, compute_dtype)
-  new_scale, new_history = compute_scale_and_amax_history(
-      x, q_dtype, scale, amax_history
-  )
+  is_fm32 = scale.dtype == fm32 and amax_history.dtype == fm32
+  # convert fm32->f32 so we can do math
+  if is_fm32:
+    amax_history = lax.convert_element_type(amax_history, jnp.float32)
+    scale = lax.convert_element_type(scale, jnp.float32)
+
+  dtype_max = get_fp8_max(q_dtype, jnp.float32)
+  amax_from_history = jnp.max(amax_history, axis=0)
+  new_scale = compute_scale(amax_from_history, scale, dtype_max)
+
+  qx = quantize_dequantize(x, q_dtype, new_scale, compute_dtype)
+
+  new_history = compute_amax_history(x, amax_history)
+
+  # convert f32->fm32 so the autodiff system accumulates fp8 meta correctly
+  if is_fm32:
+    new_history = lax.convert_element_type(new_history, fm32)
+    new_scale = lax.convert_element_type(new_scale, fm32)
   return qx, new_scale, new_history

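In the new `qdq_and_return`, the scale is derived from the existing amax history before quantizing, and only then is the history rolled forward; per the test below, `compute_scale` effectively yields amax_from_history / fp8_max once a positive amax has been seen. A small worked example of that arithmetic, not part of the diff, assuming `get_fp8_max`, `compute_scale`, and `compute_amax_history` are importable from `flax.linen.fp8_ops` as shown above (448 is the float8_e4m3fn max):

import jax.numpy as jnp
from flax.linen import fp8_ops

e4m3_max = fp8_ops.get_fp8_max(jnp.float8_e4m3fn, jnp.float32)  # expected 448.0
amax_history = jnp.array([2., 0., 0.], jnp.float32)
scale = jnp.array([1.], jnp.float32)

# The new scale comes from the max of the *existing* history: 2 / 448.
amax_from_history = jnp.max(amax_history, axis=0)
new_scale = fp8_ops.compute_scale(amax_from_history, scale, e4m3_max)

# The history is then rolled and the current amax written at index 0.
new_history = fp8_ops.compute_amax_history(jnp.array([3.], jnp.float32), amax_history)
# new_history == [3., 0., 2.], matching the test expectations below.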
47 changes: 42 additions & 5 deletions tests/linen/linen_test.py
@@ -1171,7 +1171,9 @@ def run(fp8_injection, expected_shapes):
       p = nn.DenseGeneral(features=64, name='dense')
       if fp8_injection:
         p.dot_general_cls = nn.Fp8DotGeneralOp
-      y, initial_vars = p.init_with_output(init_key, x)
+
+      init_fn = jax.jit(p.init_with_output)
+      y, initial_vars = init_fn(init_key, x)
       var_shapes = jax.tree_util.tree_map(jnp.shape, initial_vars)
       self.assertEqual(var_shapes, expected_shapes)

@@ -1216,7 +1218,9 @@ def test_fp8_train_state(self):
     dense = nn.DenseGeneral(
         features=32, use_bias=True, dot_general_cls=nn.Fp8DotGeneralOp
     )
-    variables = dense.init(init_key, x)
+
+    init_fn = jax.jit(dense.init)
+    variables = init_fn(init_key, x)
     opt = optax.adam(learning_rate=0.1)
     state = train_state.TrainState.create(
         params=variables, tx=opt, apply_fn=dense.apply
@@ -1251,15 +1255,15 @@ def loss_fn(vars):
     k = state.params['params']['kernel']

     # Manually compute the expected amax history and scaling factors.
-    amax_history_x = _roll_and_update(amax_history_x, jnp.max(jnp.abs(x)))
-    amax_history_k = _roll_and_update(amax_history_k, jnp.max(jnp.abs(k)))
-    amax_history_g = _roll_and_update(amax_history_g, jnp.max(jnp.abs(g)))
     amax_from_history_x = jnp.max(amax_history_x, axis=0)
     amax_from_history_k = jnp.max(amax_history_k, axis=0)
     amax_from_history_g = jnp.max(amax_history_g, axis=0)
     scale_x = fp8_ops.compute_scale(amax_from_history_x, scale_x, e4m3_max)
     scale_k = fp8_ops.compute_scale(amax_from_history_k, scale_k, e4m3_max)
     scale_g = fp8_ops.compute_scale(amax_from_history_g, scale_g, e5m2_max)
+    amax_history_x = _roll_and_update(amax_history_x, jnp.max(jnp.abs(x)))
+    amax_history_k = _roll_and_update(amax_history_k, jnp.max(jnp.abs(k)))
+    amax_history_g = _roll_and_update(amax_history_g, jnp.max(jnp.abs(g)))

     state = train_fn(state, x, g)

@@ -1290,6 +1294,39 @@ def loss_fn(vars):
     np.testing.assert_allclose(fp8_vars['kernel_scale'][0], scale_k)
     np.testing.assert_allclose(fp8_vars['output_grad_scale'][0], scale_g)

+  def test_fp8_meta_dtype(self):
+    f32 = jnp.dtype('float32')
+    fm32 = fp8_ops.fm32
+
+    # Build a scan loop that reuses ah_f32 and sf_f32, so autograd accumulates
+    # their grads. By converting them to the fm32 dtype, we expect the
+    # accumulation to use the max op rather than the add op.
+    def outer(x, ah_f32, sf_f32):
+      ah_fm32 = jax.lax.convert_element_type(ah_f32, fm32)
+      sf_fm32 = jax.lax.convert_element_type(sf_f32, fm32)
+      array_x = jnp.array([x], f32)
+      def body_fun(carry, _):
+        carry = fp8_ops.in_qdq(f32, carry, sf_fm32, ah_fm32)
+        return carry, None
+      array_x, _ = jax.lax.scan(body_fun, array_x, None, length=3)
+      return array_x[0]
+
+    outer_fn = jax.jit(jax.grad(outer, (0, 1, 2)))
+    ah = jnp.array([0., 0., 0.], f32)
+    sf = jnp.array([1.], f32)
+    # 1st iteration
+    grads, new_ah, new_sf = outer_fn(2.0, ah, sf)
+    np.testing.assert_allclose(new_ah, [2., 0., 0.])
+    np.testing.assert_allclose(new_sf, [1.])
+    # 2nd iteration
+    grads, new_ah, new_sf = outer_fn(3., new_ah, new_sf)
+    np.testing.assert_allclose(new_ah, [3., 0., 2.])
+    np.testing.assert_allclose(new_sf, [2. / 448])
+    # 3rd iteration
+    grads, new_ah, new_sf = outer_fn(4., new_ah, new_sf)
+    np.testing.assert_allclose(new_ah, [4., 2., 3.])
+    np.testing.assert_allclose(new_sf, [3. / 448])
+

 if __name__ == '__main__':
   absltest.main()
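For reference, the values asserted in test_fp8_meta_dtype follow directly from the roll-and-write history update and the max-based accumulation. A small standalone trace (plain NumPy, not part of the diff):

import numpy as np

def roll_update(history, amax):
  # Mirrors compute_amax_history: shift left, write the newest amax at index 0.
  history = np.roll(history, -1)
  history[0] = amax
  return history

h = np.zeros(3)
for x in (2., 3., 4.):
  h = roll_update(h, x)
  print(h)  # [2,0,0] -> [3,0,2] -> [4,2,3], matching the asserts above

# The scale asserted after each call is max(previous history) / 448:
# 1.0 on the first call (the history starts at zero, so the old scale is kept),
# then 2/448 and 3/448.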