Introduce delayed sampling mechanism #84
Merged
This change introduces a mechanism called "delayed sampling", which aims to minimize the CPU overhead of output token post-processing and next-token scheduling by overlapping the CPU-active part with device computations.
When delayed sampling is enabled, the first (prompt) model execution schedules model.forward() and the logits computation on the device, then immediately returns an output filled with invalid token ids without waiting for the computation to complete. The output logits are gathered and sampled only in the subsequent model execution, which in turn schedules the next model.forward() and logits computation and, again without waiting for the results, returns the previously collected and sampled output token ids. This process continues for the entire sequence length. As a result, the last token is computed redundantly and discarded.
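The flow described above can be illustrated with a small, self-contained sketch. This is not the code from this PR: a single-worker thread pool stands in for the device, and `fake_forward`, `sample`, `generate_delayed`, `INVALID_TOKEN_ID` and `VOCAB_SIZE` are hypothetical names chosen only for illustration.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import List, Optional, Tuple

INVALID_TOKEN_ID = -1  # hypothetical placeholder meaning "not sampled yet"
VOCAB_SIZE = 8         # toy vocabulary size, for illustration only


def fake_forward(tokens: List[int]) -> List[float]:
    """Stand-in for model.forward() plus logits computation on the device."""
    winner = (sum(tokens) + len(tokens)) % VOCAB_SIZE
    return [1.0 if i == winner else 0.0 for i in range(VOCAB_SIZE)]


def sample(logits: List[float]) -> int:
    """Greedy sampling: index of the largest logit."""
    return max(range(len(logits)), key=lambda i: logits[i])


def generate_delayed(prompt: List[int],
                     max_new_tokens: int) -> Tuple[List[int], List[int]]:
    """Return (per-step outputs, generated tokens) using delayed sampling."""
    tokens = list(prompt)
    step_outputs: List[int] = []
    pending: Optional[Future] = None
    # The single worker thread plays the role of the device (an assumption).
    with ThreadPoolExecutor(max_workers=1) as device:
        # One extra step is needed because every result arrives one step late.
        for _ in range(max_new_tokens + 1):
            if pending is None:
                # First (prompt) execution: nothing to sample yet.
                step_outputs.append(INVALID_TOKEN_ID)
            else:
                # Gather and sample the logits scheduled by the previous step.
                token = sample(pending.result())
                tokens.append(token)
                step_outputs.append(token)
            # Schedule the next forward pass and return without waiting.
            pending = device.submit(fake_forward, list(tokens))
    # The forward pass scheduled by the final step is never read: this is
    # the redundantly computed last token that gets discarded.
    return step_outputs, tokens[len(prompt):]
```

In this sketch, `generate_delayed([1, 2, 3], 4)` runs five executions: the first returns `INVALID_TOKEN_ID`, and each later one returns the token scheduled by the execution before it. Whatever CPU work the caller does between executions overlaps with the already-scheduled forward pass, which is the effect the mechanism is after.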
Please review @madamczykhabana, @kzawora-intel