forked from oobabooga/text-generation-webui
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add StreamingLLM for llamacpp & llamacpp_HF (2nd attempt) (oobabooga#…
- Loading branch information
1 parent
8e9fb06
commit c89b813
Showing
7 changed files
with
147 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import torch | ||
|
||
from modules import shared | ||
from modules.logging_colors import logger | ||
|
||
|
||
def process_llamacpp_cache(model, new_sequence, past_sequence): | ||
i1, i2, j1, j2 = find_longest_common_substring_indices(past_sequence, new_sequence) | ||
overlap_length = i2 - i1 + 1 | ||
|
||
# Do StreamingLLM if i1 > 0 (ie the longest common subsequence is not a prefix) | ||
# and the overlap length is sufficiently long. | ||
if i1 > 0 and overlap_length > 0.2 * len(new_sequence): | ||
|
||
new_sequence = torch.tensor(new_sequence) | ||
past_sequence = torch.tensor(past_sequence) | ||
|
||
prefix_length = find_prefix_length(past_sequence[:i1], new_sequence[:j1]) | ||
sink_length = prefix_length | ||
if sink_length < shared.args.attention_sink_size: | ||
sink_length = shared.args.attention_sink_size | ||
|
||
removed_length = i1 - sink_length | ||
|
||
matching_prefix = past_sequence[:prefix_length] | ||
removed_chunk = past_sequence[sink_length:i1] | ||
overlapping_sequence = new_sequence[j1:j2 + 1] | ||
added_chunk = new_sequence[j2 + 1:] | ||
|
||
# print(past_sequence) | ||
# print(new_sequence) | ||
|
||
print() | ||
print('MATCHING PREFIX=', repr(shared.tokenizer.decode(matching_prefix))) | ||
print('ADDED CHUNK=', repr(shared.tokenizer.decode(added_chunk))) | ||
print('REMOVED CHUNK=', repr(shared.tokenizer.decode(removed_chunk))) | ||
print() | ||
|
||
# Remove interval [sink_length, sink_length + removed_length) from the context | ||
# Subtract removed_length from model.n_tokens | ||
model._ctx.kv_cache_seq_rm(0, sink_length, sink_length + removed_length) | ||
model._ctx.kv_cache_seq_shift(0, sink_length + removed_length, -1, -removed_length) | ||
|
||
new_sequence = new_sequence.tolist() | ||
model.input_ids[:j2 + 1] = new_sequence[:j2 + 1] | ||
model.n_tokens = j2 + 1 | ||
|
||
return new_sequence[:j2 + 1] | ||
else: | ||
return past_sequence | ||
|
||
|
||
def find_prefix_length(past_seq, seq_tensor): | ||
''' | ||
Given two torch tensors, finds the length of the longest | ||
common prefix between the two. | ||
''' | ||
min_length = min(past_seq.shape[0], seq_tensor.shape[0]) | ||
indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length])) | ||
if len(indices) > 0: | ||
prefix_length = indices[0].item() | ||
else: | ||
prefix_length = min_length | ||
|
||
return prefix_length | ||
|
||
|
||
def find_longest_common_substring_indices(list1, list2): | ||
''' | ||
Given two lists, solves the Longest Common Substring problem. | ||
It returns the indices where the substring starts and ends in | ||
s1 and s2. | ||
Example: | ||
ir, jr, ir2, jr2 = find_longest_common_substring_indices(s1, s2) | ||
print(s1[ir:jr + 1]) | ||
print(s2[ir2:jr2 + 1]) | ||
Adapted from | ||
https://rosettacode.org/wiki/Longest_common_substring#Python | ||
''' | ||
|
||
len_list1, len_list2 = len(list1), len(list2) | ||
start_index_list1, end_index_list1 = 0, -1 | ||
start_index_list2, end_index_list2 = 0, -1 | ||
|
||
for index1 in range(len_list1): | ||
try: | ||
index2 = list2.index(list1[index1]) | ||
except ValueError: | ||
continue | ||
while index2 >= 0: | ||
temp_index1, temp_index2 = index1, index2 | ||
while temp_index1 < len_list1 and temp_index2 < len_list2 and list2[temp_index2] == list1[temp_index1]: | ||
if temp_index1 - index1 >= end_index_list1 - start_index_list1: | ||
start_index_list1, end_index_list1 = index1, temp_index1 | ||
start_index_list2, end_index_list2 = index2, temp_index2 | ||
|
||
temp_index1 += 1 | ||
temp_index2 += 1 | ||
try: | ||
index2 = list2.index(list1[index1], index2 + 1) | ||
except ValueError: | ||
break | ||
|
||
return start_index_list1, end_index_list1, start_index_list2, end_index_list2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters