Fix logic for determining the number of cache blocks #98
Conversation
Signed-off-by: Thomas Parnell <[email protected]>
Looks good, but a few comments before approval. Also, do we have an image available so we can try this out?
```python
nt_cache_block_ratio = cache_block_size / block_size / memory_scaling_model.next_token_params[1]
total_num_gpu_blocks = int(nt_cache_block_ratio * memory_scaling_model.free_memory // cache_block_size)
# we may need to increase the safety margin a bit to ensure that prefill forward does not run OOM
recommend_safety_margin = 5 + int(100*(1.0 - (1.0 - nt_cache_block_ratio)/(1.0 - pf_cache_block_ratio)))
```
Can we have a comment around this line explaining what is being done?
I will add something
I added some explanation now. This approach isn't ideal, and might affect the maximum throughput we can achieve with the server. However, I can't see any other way to ensure robustness without re-implementing the batching logic to interact with the KVCacheManager.
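To make the formula in the diff concrete, here is a small worked example with made-up numbers. All values below are hypothetical (this assumes `next_token_params[1]` is the learned per-token byte cost during the next-token phase, and similarly for the prefill coefficient); nothing here is taken from a real model.

```python
# All numbers below are hypothetical, chosen only to illustrate the formula.
MiB = 1024 * 1024

cache_block_size = 16 * MiB     # bytes per KV cache block
block_size = 16                 # tokens per block -> 1 MiB of KV cache per token
nt_coeff = 2 * MiB              # assumed learned next-token cost, bytes per token
pf_coeff = 4 * MiB              # assumed learned prefill cost, bytes per token
free_memory = 8 * 1024 * MiB    # 8 GiB free after loading weights + speculator

# Fraction of per-token memory that is KV cache, in each phase
nt_cache_block_ratio = cache_block_size / block_size / nt_coeff   # 0.5
pf_cache_block_ratio = cache_block_size / block_size / pf_coeff   # 0.25

# Set aside that fraction of free memory for cache blocks
total_num_gpu_blocks = int(nt_cache_block_ratio * free_memory // cache_block_size)

# Safety margin recommendation, as in the quoted diff
recommend_safety_margin = 5 + int(
    100 * (1.0 - (1.0 - nt_cache_block_ratio) / (1.0 - pf_cache_block_ratio))
)

print(total_num_gpu_blocks)     # 256
print(recommend_safety_margin)  # 38
```

With these numbers, KV cache accounts for half of per-token memory during decode but only a quarter during prefill, so the recommended margin grows to cover the gap.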
```python
SPECULATOR_NAME = os.getenv("SPECULATOR_NAME", None)

# speculator revision
SPECULATOR_REVISION = os.getenv("SPECULATOR_REVISION", None)
```
What is this used for when loading?
Looks like it's analogous to `MODEL_REVISION`: the specific commit hash of the model to load. Like one of these, I think: https://huggingface.co/ibm/granite-7b-lab-accelerator/commits/main
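For context, a sketch of how such a revision typically gets used when loading. The `load_speculator` wiring below is illustrative, not the repo's actual code; `revision` is a standard keyword of Hugging Face `from_pretrained` loaders, where `None` means the tip of the default branch.

```python
import os

# Pin both the speculator name and the exact commit to load,
# analogous to MODEL_NAME / MODEL_REVISION for the base model.
SPECULATOR_NAME = os.getenv("SPECULATOR_NAME", None)
SPECULATOR_REVISION = os.getenv("SPECULATOR_REVISION", None)  # e.g. a commit hash

def load_speculator(from_pretrained):
    # `from_pretrained` is any HF-style loader accepting a `revision` kwarg
    # (hypothetical wiring, for illustration only).
    return from_pretrained(SPECULATOR_NAME, revision=SPECULATOR_REVISION)
```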
```python
)
except:
    # if something goes wrong during forward, we still need to set the sequence ids
    batch.sequence_ids = cache_data.sequence_ids
```
Would it make sense to just move this from the bottom of the method to above the call to `self.model(...)`?
I think `cache_data.sequence_ids` only gets populated within the call to `self.model(...)`, so we can't move it beforehand.
This feels a bit fragile. I wonder if it would be better to revert to the prior state (if possible) when the call to `self.model` fails? Ideally within that call, i.e. avoid partial success.
I agree that it's fragile, and there might be a better way to address it from within the function. Not sure whether to prioritize that at this stage though.
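One pattern in the direction the review discusses is `try`/`finally`: the sequence ids get assigned whether or not the forward pass raises, without a bare `except:` swallowing the error. The classes and function below are a hypothetical, self-contained sketch, not the repo's code.

```python
# Minimal stand-ins for the real batch / cache-data objects (hypothetical).
class FakeCacheData:
    def __init__(self):
        self.sequence_ids = None

class FakeBatch:
    def __init__(self):
        self.sequence_ids = None

def forward_with_ids(batch, cache_data, model_fn):
    # `model_fn` stands in for self.model(...); it populates
    # cache_data.sequence_ids internally, possibly before raising.
    try:
        return model_fn(cache_data)
    finally:
        # Runs on success *and* on error, so the ids are never lost,
        # and any exception still propagates to the caller.
        batch.sequence_ids = cache_data.sequence_ids
```

This keeps the "ids must always be set" invariant local to one place while letting the caller decide how to handle the failure.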
Signed-off-by: Thomas Parnell <[email protected]>
lgtm
Signed-off-by: Nick Hill <[email protected]>
Motivation
When we deploy spec decoding in prod, we are frequently seeing the servers run out of free blocks. We have determined that this is due to two issues:
1. The constraint on `SPECULATOR_MAX_BATCH_SIZE` is not enough to avoid running into memory pressure due to speculation: we need to ensure that we do not speculate on batches that may have a small "size" but a very large weight.
2. The computation of the number of blocks is very wrong in most cases.
Modifications
1. I have introduced an additional constraint that says we should only speculate on batches with weight up to 75% of the weight limit. This should ensure that we never speculate when we are close to the memory limits.
2. I have written new code to calculate the number of KV cache blocks. This calculation uses the memory scaling coefficients that we have learned at startup. In particular, it uses the learned coefficients to figure out what percentage of the memory capacity needs to be set aside for cache blocks.
3. In the above calculation, I use the next-token coefficient rather than the prefill coefficient, since during the next-token phase the KV cache blocks typically comprise a relatively large percentage of total memory consumption, and we need to be able to handle this worst case. However, this means that during prefill steps we may not have enough memory left over to store the auxiliary data structures we need for a forward pass. There isn't really a clean way to handle this other than re-writing the router logic to be block-aware, but what we can do is recommend that the user increase the batch safety margin to a level that ensures prefills will not run OOM. I've added a print statement to provide this guidance.
4. I now load the speculator before learning the memory scaling model, since we also need to take it into account when measuring the amount of free memory.
Result
These changes, together with setting `BATCH_SAFETY_MARGIN=35`, seem to result in robust behaviour for both `llama3-8b` and `granite-20b`. We no longer need to manually set the number of KV cache blocks in the latter case.
Related Issues
n/a
---------
Signed-off-by: Thomas Parnell <[email protected]>