Question: Does low-bit config reduce TPU HBM memory usage when training? #66
Replies: 2 comments
-
hello and thanks

# hidden_state: an array with the shape (N-dims)
# q_proj: the query-projection Dense layer in Flax
# q_flax.QDotGeneral: built into FJFormer (the computation backend for EasyDel)
if self.config.bits is not None:
    _dot_general_cls = q_config.fully_quantized(
        fwd_bits=self.config.bits,
        bwd_bits=self.config.bits
    )
else:
    _dot_general_cls = None
dot_general_cls = q_flax.QDotGeneral(_dot_general_cls)
q_proj = nn.Dense(..., dot_general=dot_general_cls)
# this layer will now take any array and run the matmul / dot operation in the
# given number of bits, for example 8, 6 or 4
# instead of using jax.lax.dot_general we use FJFormer's QDotGeneral like this:
#
# from fjformer.bits import q_flax, config as q_config
#
# dot_general = q_flax.QDotGeneral(q_config.fully_quantized(
#     bwd_bits=4,
#     fwd_bits=4
# ))
hidden_state = q_proj(hidden_state)
# hidden_state is still bfloat16, float16 or float32, but the operations are
# computed in the given number of bits
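To make that concrete, here is a minimal, self-contained sketch of the same pattern; CheckDense is just an illustrative name, and it assumes the same fjformer.bits API shown in the comments above:

# Illustrative sketch (hypothetical CheckDense module): parameters and outputs
# stay ordinary float arrays; only the dot_general inside nn.Dense is quantized.
import jax
import jax.numpy as jnp
from flax import linen as nn
from fjformer.bits import q_flax, config as q_config


class CheckDense(nn.Module):
    @nn.compact
    def __call__(self, x):
        dot_general = q_flax.QDotGeneral(
            q_config.fully_quantized(fwd_bits=8, bwd_bits=8)
        )
        return nn.Dense(features=16, dot_general=dot_general)(x)


x = jnp.ones((2, 16), dtype=jnp.float32)  # bfloat16/float16 behave the same way
module = CheckDense()
params = module.init(jax.random.PRNGKey(0), x)
y = module.apply(params, x)
print(jax.tree_util.tree_map(lambda a: a.dtype, params))  # float parameter dtypes
print(y.dtype)  # still a float dtype; the matmul itself ran in 8 bits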
Run and find out: let's see how the full code actually computes by running it.

import chex
from fjformer.bits import config, q_flax as q
from flax import linen as nn
from typing import Optional
import jax


class MLP(nn.Module):
    qnt_config: Optional[config.DotGeneral] = None

    @nn.compact
    def __call__(self, inputs: chex.Array):
        # if None is passed, plain jax.lax.dot_general is used
        dot_general = q.QDotGeneral(self.qnt_config)
        x = nn.Dense(dot_general=dot_general, features=inputs.shape[-1] * 4)(inputs)
        x = nn.silu(x)
        x = nn.Dense(dot_general=dot_general, features=inputs.shape[-1])(x)
        return x + inputs


def init_and_eval(name, inputs, mlp_block, init_seed=0, eval_seed=0):
    params = mlp_block.init(jax.random.PRNGKey(init_seed), inputs)
    out = mlp_block.apply(params, inputs, rngs={'params': jax.random.key(eval_seed)})
    print(f"{name}:\n", out)


def main():
    int8_config = config.fully_quantized(fwd_bits=8, bwd_bits=8)
    mlp_fp16 = MLP(qnt_config=None)
    mlp_int8 = MLP(qnt_config=int8_config)
    input_ = jax.random.normal(jax.random.key(0), (1, 600, 4))
    init_and_eval('mlp_fp16', input_, mlp_fp16)
    init_and_eval('mlp_int8', input_, mlp_int8)


if __name__ == "__main__":
    main()

The output shows the same MLP evaluated twice: mlp_fp16 computed in float16/float32 and mlp_int8 with its dot operations computed in int8 (the printed arrays are omitted here).
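Since the question is specifically about HBM usage, here is a sketch of one way to compare what the compiler actually allocates for the two MLPs. It is not part of the original script; it assumes you append it after the script above and run it on a TPU backend, where Compiled.memory_analysis() typically returns size information (on other backends it may return None or less detail):

# Sketch: compare the compiled memory footprint of the unquantized and the
# int8-quantized MLP defined above (append after the script above).
def report_memory(name, mlp_block, inputs):
    params = mlp_block.init(jax.random.PRNGKey(0), inputs)
    compiled = jax.jit(mlp_block.apply).lower(params, inputs).compile()
    # on TPU this typically reports argument/output/temp buffer sizes
    print(name, compiled.memory_analysis())


inputs = jax.random.normal(jax.random.key(0), (1, 600, 4))
int8_cfg = config.fully_quantized(fwd_bits=8, bwd_bits=8)
report_memory("mlp_fp16", MLP(qnt_config=None), inputs)
report_memory("mlp_int8", MLP(qnt_config=int8_cfg), inputs)

As the comments above note, the parameters and hidden states themselves stay in their float dtype; this check just shows what the compiler actually allocates for each version.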
-
If you still have any other questions, I'll be happy to answer them and help you.
-
Hello,
Firstly, I'd like to express my appreciation for your work on this repository. I noticed that it supports low-bit (4 or 8 bits) formats during training, which is quite intriguing.
I have a query regarding TPU compatibility, particularly for TPUs before TPUv4. Since TPUs generally didn't support low-bit formats such as 4-bit or 8-bit until TPUv4 (which supports int8), I'm curious how this implementation works. My current understanding is that the code might be converting the 4 or 8-bit formats into bfloat16 or float16. If that's the case, would it imply that the memory-usage reduction typically expected from lower-bit formats might not be realized?
Could you please clarify if my understanding is correct? Thanks for your time and effort in developing and maintaining this repository.