Merge pull request #4229 from wenscarl:scale_type_fix
PiperOrigin-RevId: 681055894
Flax Authors committed Oct 1, 2024
2 parents 96b0edc + 5edc65a commit b9bbc98
Showing 2 changed files with 148 additions and 93 deletions.
239 changes: 147 additions & 92 deletions flax/linen/fp8_ops.py
@@ -139,8 +139,7 @@ def quantize(x, q_dtype, scale, compute_dtype):
def dequantize(x, dq_dtype, scale):
return x.astype(dq_dtype) * jnp.broadcast_to(scale.astype(dq_dtype), x.shape)


def quantize_dequantize(x, q_dtype, scale, compute_dtype):
def qdq(x, q_dtype, scale, compute_dtype):
qx = quantize(x, q_dtype, scale, compute_dtype)
return dequantize(qx, x.dtype, scale)
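For context, `quantize` (whose body sits outside this hunk) is the inverse of the `dequantize` shown above: it divides by the scale, clips to the finite range of the fp8 dtype, and casts. A minimal sketch, assuming `scale` is the scalar dequantization multiplier; the helper name `quantize_sketch` is ours, not the library's:

```python
import jax.numpy as jnp

def quantize_sketch(x, q_dtype, scale, compute_dtype):
  # `scale` multiplies at dequantization time, so its inverse is applied here.
  dtype_max = float(jnp.finfo(q_dtype).max)
  scaled = (x / jnp.broadcast_to(scale.astype(x.dtype), x.shape)).astype(compute_dtype)
  clipped = jnp.clip(scaled, -dtype_max, dtype_max)
  return clipped.astype(q_dtype)
```

The renamed `qdq` then round-trips a tensor through fp8 and back ("fake quantization"), which is how quantization error is injected during training without running the matmul itself in fp8.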

@@ -165,8 +164,8 @@ def compute_amax_history(x, amax_history):
return new_history


def quantize_and_update(
x, q_dtype, scale, amax_history, compute_dtype, use_direct_quant=False
def update_fp8_meta(
x, q_dtype, scale, amax_history
):
is_fmax32 = (scale.dtype == fm32 and amax_history.dtype == fm32)
# convert fm32->f32 so we can do math
@@ -181,20 +180,20 @@ def quantize_and_update(
new_scale = compute_scale(amax_from_history, scale, dtype_max)
new_history = compute_amax_history(x, amax_history)

# convert f32->fmax32 so the autodiff system accumulates fp8 meta correctly
if is_fmax32:
new_history = lax.convert_element_type(new_history, fp32_max_grad)
new_scale = lax.convert_element_type(new_scale, fp32_max_grad)

# Quantize the input
if not use_direct_quant:
qx = quantize_dequantize(x, q_dtype, new_scale, compute_dtype)
return qx, new_scale, new_history

return new_scale, new_history

def quantize_dequantize_update(x, q_dtype, scale, amax_history, compute_dtype):
updated_scale, updated_history = update_fp8_meta(x, q_dtype, scale, amax_history)
qdq_x = qdq(x, q_dtype, _fm32_to_float32(updated_scale), compute_dtype)
return qdq_x, updated_scale, updated_history

return qx, new_scale, new_history
def _fm32_to_float32(value):
if value.dtype == fm32:
return lax.convert_element_type(value, jnp.float32)
return value
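Two things are worth spelling out here. First, `update_fp8_meta` implements delayed scaling: the new scale is derived from a rolling window of recently observed `max(|x|)` values rather than from the current tensor alone. Second, `fm32` (`fp32_max_grad`) is a custom dtype that carries f32 bits but tells autodiff to combine duplicate cotangents with `max` instead of `+`, so fp8 metas reaching a vjp from several branches merge by maximum rather than being summed. A sketch of the two helpers called above (their bodies are outside this hunk), following the standard delayed-scaling recipe; treat it as illustrative, not the exact library code:

```python
import jax.numpy as jnp

def compute_scale_sketch(amax, scale, dtype_max, margin=0.0):
  # Pick a power-of-two scaling factor from the observed amax; keep the
  # previous value when amax is zero or non-finite.
  exp = jnp.floor(jnp.log2(dtype_max / amax)) - margin
  sf = jnp.round(2.0 ** jnp.abs(exp))
  sf = jnp.where(amax > 0.0, sf, scale)
  sf = jnp.where(jnp.isfinite(amax), sf, scale)
  sf = jnp.where(exp < 0, 1.0 / sf, sf)
  return 1.0 / sf  # stored as the dequantization multiplier

def compute_amax_history_sketch(x, amax_history):
  # Rolling window: shift every slot down and record the current amax at 0.
  amax_update = jnp.max(jnp.abs(x)).astype(amax_history.dtype)
  return jnp.roll(amax_history, shift=-1, axis=0).at[0].set(amax_update)
```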

def dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
preferred_element_type: DTypeLike | None,
@@ -242,14 +241,14 @@ def dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision,

@partial(custom_vjp, nondiff_argnums=(0, 1))
def in_qdq(compute_dtype, q_dtype, inp, scale, amax_history):
qin, _, _ = quantize_and_update(
qin, _, _ = quantize_dequantize_update(
inp, q_dtype, scale, amax_history, compute_dtype
)
return qin


def in_qdq_fwd(compute_dtype, q_dtype, inp, scale, amax_history):
qin, new_scale, new_history = quantize_and_update(
qin, new_scale, new_history = quantize_dequantize_update(
inp, q_dtype, scale, amax_history, compute_dtype
)
return qin, (new_scale, new_history)
@@ -275,7 +274,7 @@ def out_qdq_fwd(compute_dtype, q_dtype, out, scale, amax_history):

def out_qdq_bwd(compute_dtype, q_dtype, res, g):
scale, amax_history = res
q_g, new_scale, new_history = quantize_and_update(
q_g, new_scale, new_history = quantize_dequantize_update(
g, q_dtype, scale, amax_history, compute_dtype
)
return q_g, new_scale, new_history
@@ -284,91 +283,103 @@ def out_qdq_bwd(compute_dtype, q_dtype, res, g):
out_qdq.defvjp(out_qdq_fwd, out_qdq_bwd)
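`in_qdq` and `out_qdq` follow the straight-through-estimator pattern: the forward pass fake-quantizes, while `in_qdq`'s backward passes the gradient through unchanged and `out_qdq`'s backward fake-quantizes the incoming gradient (to e5m2) instead of the activation. A toy, self-contained sketch of the pattern, using a bfloat16 round-trip in place of fp8:

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def fake_quant(q_dtype, x):
  return x.astype(q_dtype).astype(x.dtype)  # lossy round-trip

def fake_quant_fwd(q_dtype, x):
  return fake_quant(q_dtype, x), None

def fake_quant_bwd(q_dtype, res, g):
  return (g,)  # straight-through: the rounding is treated as identity

fake_quant.defvjp(fake_quant_fwd, fake_quant_bwd)

y, vjp = jax.vjp(partial(fake_quant, jnp.bfloat16), jnp.arange(4.0))
```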


def q_dot_dq_impl(
lhs,
rhs,
lhs_scale,
rhs_scale,
out_grad_scale,
lhs_amax_history,
rhs_amax_history,
out_grad_amax_history,
compute_dtype,
dimension_numbers,
precision,
preferred_element_type,
is_training
):
new_lhs_scale, new_lhs_amax_history = quantize_and_update(
lhs,
jnp.float8_e4m3fn,
lhs_scale,
lhs_amax_history,
compute_dtype,
use_direct_quant=True
)
new_rhs_scale, new_rhs_amax_history = quantize_and_update(
rhs,
jnp.float8_e4m3fn,
rhs_scale,
rhs_amax_history,
compute_dtype,
use_direct_quant=True
@partial(custom_vjp, nondiff_argnums=(0, 1))
def in_q(compute_dtype, q_dtype, inp, scale, amax_history):
new_scale, _ = update_fp8_meta(inp, q_dtype, scale, amax_history)
qin = quantize(inp, q_dtype, _fm32_to_float32(new_scale), compute_dtype)
return qin, new_scale

def in_q_fwd(compute_dtype, q_dtype, inp, scale, amax_history):
new_scale, new_history = update_fp8_meta(inp, q_dtype, scale, amax_history)
qin = quantize(inp, q_dtype, _fm32_to_float32(new_scale), compute_dtype)
return (qin, new_scale), (new_scale, new_history)

def in_q_bwd(compute_dtype, q_dtype, res, _):
new_scale, new_history = res
# No gradient flows to inp; the refreshed scale and amax_history are returned in the scale/history gradient slots so the fp8 metas get updated.
return None, new_scale, new_history

in_q.defvjp(in_q_fwd, in_q_bwd)
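Unlike `in_qdq`, `in_q` returns the real fp8 tensor together with the refreshed scale, so the caller can feed both into the fp8 matmul; its backward produces no input gradient (the matmul's own custom vjp differentiates against the original operands) and instead routes the updated metas back through the scale/history slots. A hypothetical call, assuming freshly initialized metas (the history length of 16 is illustrative):

```python
import jax.numpy as jnp

x = jnp.ones((4, 8), jnp.bfloat16)
scale = jnp.array(1.0, jnp.float32)      # dequantization multiplier
history = jnp.zeros((16,), jnp.float32)  # rolling amax window

q_x, new_scale = in_q(jnp.bfloat16, jnp.float8_e4m3fn, x, scale, history)
```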


@partial(custom_vjp, nondiff_argnums=(0, ))
def out_dq(dq_type, lhs_scale, rhs_scale, out):
q_out = dequantize(
out,
dq_type,
_fm32_to_float32(lhs_scale) * _fm32_to_float32(rhs_scale)
)
return q_out

def out_dq_fwd(dq_type, lhs_scale, rhs_scale, out):
return out_dq(dq_type, lhs_scale, rhs_scale, out), None

def out_dq_bwd(dq_type, _, g):
return None, None, g

out_dq.defvjp(out_dq_fwd, out_dq_bwd)
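`out_dq` undoes both input quantizations at once: since `(lhs/s_l) @ (rhs/s_r) == (lhs @ rhs) / (s_l * s_r)`, multiplying the raw fp8 dot result by `s_l * s_r` recovers the unscaled product. Its backward passes the gradient straight to `out` and none to the scales. A numeric sanity check of that identity:

```python
import jax.numpy as jnp

lhs = jnp.array([[2.0, 4.0]])
rhs = jnp.array([[8.0], [16.0]])
s_l, s_r = 0.5, 0.25

raw = (lhs / s_l) @ (rhs / s_r)          # what the scaled fp8 dot computes
assert jnp.allclose(raw * (s_l * s_r), lhs @ rhs)
```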

q_lhs = quantize(lhs, jnp.float8_e4m3fn, new_lhs_scale, preferred_element_type)
q_rhs = quantize(rhs, jnp.float8_e4m3fn, new_rhs_scale, preferred_element_type)

def quantized_dot_impl(
lhs,
q_lhs,
lhs_scale, # actually the new lhs scale
rhs,
q_rhs,
rhs_scale, # actually the new rhs scale
out_grad_scale, # old out grad scale
out_grad_amax_history, # old out grad amax history
compute_dtype,
dimension_numbers,
precision,
preferred_element_type,
is_training
):
out = lax.dot_general(
q_lhs,
q_rhs,
dimension_numbers,
preferred_element_type=preferred_element_type,
precision=lax.Precision.DEFAULT,
)

out = dequantize(out, preferred_element_type, new_lhs_scale * new_rhs_scale)
if is_training:
res = (
lhs,
rhs,
q_lhs,
lhs_scale,
rhs,
q_rhs,
new_lhs_scale,
new_rhs_scale,
rhs_scale,
out_grad_scale,
new_lhs_amax_history,
new_rhs_amax_history,
out_grad_amax_history,
)
return out, res
else:
return out
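The core of the impl is a plain `lax.dot_general` over the two fp8 operands with a wider `preferred_element_type` for accumulation; the residuals saved under `is_training` are exactly what the backward pass needs (original operands for the transpose rules, quantized operands and scales for dequantizing the gradients). A minimal fp8 dot in the same spirit, which should run on recent JAX versions (backends without native fp8 matmul support may upcast internally):

```python
import jax.numpy as jnp
from jax import lax

a = jnp.ones((4, 8), jnp.float8_e4m3fn)
b = jnp.ones((8, 16), jnp.float8_e4m3fn)
# Contract a's dim 1 with b's dim 0; accumulate in bfloat16.
y = lax.dot_general(a, b, (((1,), (0,)), ((), ())),
                    preferred_element_type=jnp.bfloat16)
```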


@partial(custom_vjp, nondiff_argnums=(8, 9, 10, 11))
def q_dot_dq(
def quantized_dot(
lhs,
q_lhs,
lhs_scale, # actually the new lhs scale
rhs,
lhs_scale,
q_rhs,
rhs_scale,
out_grad_scale,
lhs_amax_history,
rhs_amax_history,
out_grad_amax_history,
out_grad_scale, # old out grad scale
out_grad_amax_history, # old out grad amax history
compute_dtype,
dimension_numbers,
precision=None,
preferred_element_type=None
):
return q_dot_dq_impl(
return quantized_dot_impl(
lhs,
rhs,
q_lhs,
lhs_scale,
rhs,
q_rhs,
rhs_scale,
out_grad_scale,
lhs_amax_history,
rhs_amax_history,
out_grad_amax_history,
compute_dtype,
dimension_numbers,
@@ -377,29 +388,28 @@ def q_dot_dq_fwd(
is_training=False,
)


def q_dot_dq_fwd(
def quantized_dot_fwd(
lhs,
rhs,
q_lhs,
lhs_scale,
rhs,
q_rhs,
rhs_scale,
out_grad_scale,
lhs_amax_history,
rhs_amax_history,
out_grad_amax_history,
compute_dtype,
dimension_numbers,
precision,
preferred_element_type,
):
return q_dot_dq_impl(
return quantized_dot_impl(
lhs,
rhs,
q_lhs,
lhs_scale,
rhs,
q_rhs,
rhs_scale,
out_grad_scale,
lhs_amax_history,
rhs_amax_history,
out_grad_amax_history,
compute_dtype,
dimension_numbers,
@@ -408,8 +418,7 @@ def q_dot_dq_fwd(
is_training=True
)


def q_dot_dq_bwd(
def quantized_dot_bwd(
compute_dtype,
dimension_numbers,
precision,
@@ -419,27 +428,23 @@ def q_dot_dq_bwd(
):
(
lhs,
rhs,
q_lhs,
lhs_scale,
rhs,
q_rhs,
new_lhs_scale,
new_rhs_scale,
rhs_scale,
out_grad_scale,
new_lhs_amax_history,
new_rhs_amax_history,
out_grad_amax_history,
) = res

new_out_grad_scale, new_out_grad_amax_history = quantize_and_update(
new_out_grad_scale, new_out_grad_amax_history = update_fp8_meta(
g,
jnp.float8_e5m2,
out_grad_scale,
out_grad_amax_history,
compute_dtype,
use_direct_quant=True
)

q_g = quantize(g, jnp.float8_e5m2, new_out_grad_scale, preferred_element_type)
q_g = quantize(g, jnp.float8_e5m2, _fm32_to_float32(new_out_grad_scale), preferred_element_type)

grad_lhs = dot_general_transpose_lhs(
q_g,
@@ -449,7 +454,11 @@ def q_dot_dq_bwd(
precision=lax.Precision.HIGHEST,
preferred_element_type=preferred_element_type,
)
grad_lhs = dequantize(grad_lhs, preferred_element_type, new_rhs_scale * new_out_grad_scale)
grad_lhs = dequantize(
grad_lhs,
preferred_element_type,
_fm32_to_float32(rhs_scale) * _fm32_to_float32(new_out_grad_scale)
)

grad_rhs = dot_general_transpose_rhs(
q_g,
@@ -459,21 +468,67 @@ def q_dot_dq_bwd(
precision=lax.Precision.HIGHEST,
preferred_element_type=preferred_element_type,
)
grad_rhs = dequantize(grad_rhs, preferred_element_type, new_lhs_scale * new_out_grad_scale)
grad_rhs = dequantize(
grad_rhs,
preferred_element_type,
_fm32_to_float32(lhs_scale) * _fm32_to_float32(new_out_grad_scale)
)

return (
grad_lhs,
None,
None,
grad_rhs,
new_lhs_scale,
new_rhs_scale,
None,
None,
new_out_grad_scale,
new_lhs_amax_history,
new_rhs_amax_history,
new_out_grad_amax_history,
)

q_dot_dq.defvjp(q_dot_dq_fwd, q_dot_dq_bwd)
quantized_dot.defvjp(quantized_dot_fwd, quantized_dot_bwd)
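Note the dtype split in the backward pass: activations and weights are quantized to e4m3 (more mantissa, less range), while gradients use e5m2, whose extra exponent bit buys the dynamic range gradients typically need:

```python
import jax.numpy as jnp

print(jnp.finfo(jnp.float8_e4m3fn).max)  # 448.0   — activations / weights
print(jnp.finfo(jnp.float8_e5m2).max)    # 57344.0 — gradients
```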

# Convenience wrapper for the quantize-dot-dequantize pattern
def q_dot_dq(
lhs,
rhs,
lhs_scale,
rhs_scale,
out_grad_scale,
lhs_amax_history,
rhs_amax_history,
out_grad_amax_history,
compute_dtype,
dimension_numbers,
precision=None,
preferred_element_type=None
):
q_lhs, new_lhs_scale = in_q(
compute_dtype, jnp.float8_e4m3fn, lhs, lhs_scale, lhs_amax_history
)
q_rhs, new_rhs_scale = in_q(
compute_dtype, jnp.float8_e4m3fn, rhs, rhs_scale, rhs_amax_history
)
y = quantized_dot(
lhs,
q_lhs,
new_lhs_scale,
rhs,
q_rhs,
new_rhs_scale,
out_grad_scale,
out_grad_amax_history,
compute_dtype,
dimension_numbers,
precision,
preferred_element_type
)
y = out_dq(
dq_type=preferred_element_type,
lhs_scale=new_lhs_scale,
rhs_scale=new_rhs_scale,
out=y
)
return y # type: ignore
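A hypothetical end-to-end call of the wrapper, with all metas freshly initialized (the shapes and the history length of 16 are illustrative, not the library defaults):

```python
import jax.numpy as jnp

lhs = jnp.ones((4, 8), jnp.bfloat16)
rhs = jnp.ones((8, 16), jnp.bfloat16)
scales = [jnp.array(1.0, jnp.float32) for _ in range(3)]
hists = [jnp.zeros((16,), jnp.float32) for _ in range(3)]

y = q_dot_dq(
    lhs, rhs,
    scales[0], scales[1], scales[2],   # lhs / rhs / out-grad scales
    hists[0], hists[1], hists[2],      # matching amax histories
    compute_dtype=jnp.bfloat16,
    dimension_numbers=(((1,), (0,)), ((), ())),
    preferred_element_type=jnp.bfloat16,
)
```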

@partial(custom_jvp, nondiff_argnums=(2, 3, 4))
def dot_general_with_precision(
