Remove hardcoded value from softmax in flat_pa (HabanaAI#280)
This PR removes the hardcoded value used to normalize softmax in flat_pa.
The current approach is to use the global maximum, as it is very easy to
compute, but it has the drawback that other samples in a batch might
slightly affect numerical stability.

This is a first step toward eliminating some of the INF/NaN issues we see
in certain configurations; it is by no means a complete solution and needs
to be revised in the future.
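For context, here is a minimal sketch (not part of the commit) of the standard max-subtraction trick the change relies on. Subtracting any constant from the logits leaves the softmax output mathematically unchanged, so replacing the hardcoded 10.0 with the actual maximum keeps every exponent at or below zero and prevents exp() from overflowing:

import torch

def stable_softmax(x: torch.Tensor) -> torch.Tensor:
    # Subtracting a constant c from the logits cancels out:
    # exp(x - c) / sum(exp(x - c)) == exp(x) / sum(exp(x)).
    # Using the *global* max is cheap, but a large logit in one
    # sample pushes every other sample's exponents further below
    # zero -- the batch interaction the commit message notes.
    x = x - x.max()
    e = x.exp()
    return e / e.sum(dim=-1, keepdim=True)

probs = stable_softmax(torch.randn(2, 5))  # rows sum to 1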
madamczykhabana authored and zhouyu5 committed Sep 13, 2024
1 parent 9cec48d commit 7b4d448
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion vllm/hpu/ops.py

@@ -40,7 +40,18 @@ def block2batch(tensor, block_mapping):
 
 
 def block_softmax(batch_size, attn, block_mapping):
-    attn.sub_(10.0)
+    # We're using the global maximum to decrease the exponent as
+    # it's fast to compute and performs reasonably well.
+    # This is by no means a final solution and needs to
+    # be properly addressed in the future.
+    #
+    # Additionally, there's a bug where 'max' is not parallelized
+    # across TPC cores, so we need to split the tensor manually
+    # instead of simply doing attn_max = attn.max().
+
+    tail_dims = tuple(range(1, attn.dim()))
+    attn_max = attn.amax(tail_dims).amax()
+    attn.sub_(attn_max)
     attn = attn.exp_()
     sums = attn.sum(dim=-1).unsqueeze(-1)
     sums = block2batch(sums, block_mapping)
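A hypothetical illustration of the two-stage reduction added above (the shapes here are made up; the real attn layout comes from flat_pa). Reducing over the tail dimensions first and then over the remaining dimension yields the same scalar as a single attn.max() call, while working around the unparallelized 'max' mentioned in the comment:

import torch

# Hypothetical shapes for illustration only.
attn = torch.randn(4, 8, 128)              # e.g. (blocks, heads, block_size)
tail_dims = tuple(range(1, attn.dim()))    # (1, 2)

# Stage 1: per-block maxima; stage 2: global maximum over blocks.
attn_max = attn.amax(tail_dims).amax()
assert torch.equal(attn_max, attn.max())   # same value as the single-call max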
