
[Feature] Prefix sharing. #53

Merged
merged 18 commits into intel:main from the prefix branch
Nov 28, 2023
Conversation

@Duyi-Wang (Contributor) commented Nov 14, 2023

Supports Llama, ChatGLM2, Baichuan, and OPT. The ChatGLM (v1) model is not supported.

@Duyi-Wang Duyi-Wang marked this pull request as draft November 14, 2023 08:52
@Duyi-Wang Duyi-Wang added the enhancement New feature or request label Nov 15, 2023
@Duyi-Wang Duyi-Wang marked this pull request as ready for review November 22, 2023 08:20
@@ -24,27 +24,41 @@ class KVCacheManager {
this->layers = layers;
this->cachedKeys = new KVCacheTensor<KVCacheT>[layers];
this->cachedValues = new KVCacheTensor<KVCacheT>[layers];
this->cachedPrefixKeys = new KVCacheTensor<KVCacheT>[layers];
Contributor
If prefix_sharing=false, there is no need to allocate it (even though the memory footprint is small).
Suggest allocating it only when it is actually needed.

this->getPositionIds(prefixIDs, batchSize, pastSeqLen, 0);

free(prefixIDs);
ids = newIDs;
Contributor

Any chance to free the IDs in the future, since they are dynamically allocated?


this->prepareAttnMask(prefixIDs, 0);

this->getPositionIds(prefixIDs, batchSize, pastSeqLen, 0);
Contributor

Do we really need to call getPositionIds?

p[keyLen - 1] * ctx->attFactor);
p[2] * ctx->attFactor, p[strideC - 3] * ctx->attFactor, p[strideC - 2] * ctx->attFactor,
p[strideC - 1] * ctx->attFactor);
// for (int qki = 0; qki < queryLen; qki++) {
Contributor

If not needed, please remove such commented-out code.

Contributor (Author)

This is used to print the whole QK score and attention mask matrix.

memcpy(newIDs + inputSeqLen * bs, ids + seqLen * bs + pastSeqLen, inputSeqLen * sizeof(int));
}

this->prepareAttnMask(prefixIDs, 0);
Contributor

What is the purpose of this step?

@pujiang2018 pujiang2018 merged commit 637eb49 into intel:main Nov 28, 2023
1 check passed
@Duyi-Wang Duyi-Wang deleted the prefix branch November 29, 2023 08:02
Labels
enhancement New feature or request
2 participants