[NPU] Reuse prefill of acc lib for pipeline #12279

Merged 13 commits on Oct 28, 2024
@@ -52,6 +52,8 @@ def get_prompt(message: str, chat_history: list[tuple[str, str]],
help='Prompt to infer')
parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
parser.add_argument("--max-output-len", type=int, default=1024)
parser.add_argument("--max-prompt-len", type=int, default=960)
parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)

args = parser.parse_args()
model_path = args.repo_id_or_model_path
@@ -60,8 +62,10 @@ def get_prompt(message: str, chat_history: list[tuple[str, str]],
optimize_model=True,
pipeline=True,
max_output_len=args.max_output_len,
max_prompt_len=args.max_prompt_len,
torch_dtype=torch.float16,
attn_implementation="eager")
attn_implementation="eager",
transpose_value_cache=not args.disable_transpose_value_cache)

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

@@ -58,6 +58,8 @@ def get_prompt(user_input: str, chat_history: list[tuple[str, str]],
help='Prompt to infer')
parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
parser.add_argument("--max-output-len", type=int, default=1024)
parser.add_argument("--max-prompt-len", type=int, default=960)
parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)

args = parser.parse_args()
model_path = args.repo_id_or_model_path
@@ -67,7 +69,10 @@ def get_prompt(user_input: str, chat_history: list[tuple[str, str]],
optimize_model=True,
pipeline=True,
max_output_len=args.max_output_len,
attn_implementation="eager")
max_prompt_len=args.max_prompt_len,
torch_dtype=torch.float16,
attn_implementation="eager",
transpose_value_cache=not args.disable_transpose_value_cache)

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
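
Note: both example scripts in this diff gain `--max-prompt-len` (default 960) and `--disable-transpose-value-cache` and forward them into `from_pretrained`. A minimal sketch of the resulting call follows; the `AutoModelForCausalLM` import path and the placeholder model id are assumptions based on ipex-llm's NPU examples, not part of this diff.

```python
# Sketch only: the import path and model id are assumed; the keyword arguments
# mirror the updated examples above. Other loading options used by the full
# examples are omitted here.
import torch
from transformers import AutoTokenizer
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

model_path = "meta-llama/Llama-2-7b-chat-hf"   # placeholder model id
max_output_len, max_prompt_len = 1024, 960     # defaults of the argparse flags above
disable_transpose_value_cache = False          # set True to mirror the new flag

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    optimize_model=True,
    pipeline=True,
    max_output_len=max_output_len,
    max_prompt_len=max_prompt_len,                             # new argument
    torch_dtype=torch.float16,
    attn_implementation="eager",
    transpose_value_cache=not disable_transpose_value_cache,   # new argument
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
```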

1 change: 1 addition & 0 deletions python/llm/src/ipex_llm/transformers/npu_model.py
@@ -259,6 +259,7 @@ def optimize_npu_model(cls, *args, **kwargs):
import convert_llm
convert_llm(llm,
kv_len=max_output_len,
max_prompt_len=max_prompt_len,
transpose_value_cache=transpose_value_cache)

return model
70 changes: 45 additions & 25 deletions python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
@@ -154,6 +154,44 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
model.lm_head = new_linear


def convert_llama(
model: torch.nn.Module,
max_output_len=1024,
max_prompt_len=1024,
decoder=False,
inter_pp=None,
intra_pp=None,
transpose_value_cache=True,
):
from ipex_llm.transformers.npu_models.llama_mp import gen_llama_fused_model_forward
from ipex_llm.transformers.npu_models.llama_mp import DecodeRunner, PrefillRunner
from transformers.models.llama.modeling_llama import LlamaModel

if decoder:
decode_runner = DecodeRunner(
model,
max_seq_len=max_output_len,
inter_pp=inter_pp,
intra_pp=intra_pp,
transpose_value_cache=transpose_value_cache,
)
else:
decode_runner = None
prefill_runner = PrefillRunner(
model,
max_output_len=max_output_len,
max_prompt_len=max_prompt_len,
transpose_value_cache=transpose_value_cache,
)
llama_model_forward = gen_llama_fused_model_forward(
prefill_runner=prefill_runner, decode_runner=decode_runner
)
convert_forward(model, LlamaModel, llama_model_forward)
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from ipex_llm.transformers.npu_models.llama_mp import llama2_casullm_forward
convert_forward(model, LlamaForCausalLM, llama2_casullm_forward)
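
The new `convert_llama` helper factors the Llama-specific runner wiring out of `optimize_llm`, so the C++ pipeline path can reuse only the prefill side while the existing Python path keeps both runners. A hedged sketch of the two call patterns introduced in this PR (wrapper names are hypothetical; `model` must be a Llama model already prepared by the NPU optimizations):

```python
# Illustrative wrappers (hypothetical names) for the two call sites added in this PR.
from ipex_llm.transformers.npu_models.convert_mp import convert_llama

def wire_for_cpp_pipeline(model, max_output_len=1024, max_prompt_len=960,
                          transpose_value_cache=True):
    # convert_llm path: prefill runner only; decoding runs in the C++ pipeline.
    convert_llama(model,
                  max_output_len=max_output_len,
                  max_prompt_len=max_prompt_len,
                  decoder=False,
                  transpose_value_cache=transpose_value_cache)

def wire_for_multiprocess(model, max_output_len=1024, max_prompt_len=960,
                          inter_pp=2, intra_pp=2, transpose_value_cache=True):
    # optimize_llm path: fused decode runner plus prefill runner.
    convert_llama(model,
                  max_output_len=max_output_len,
                  max_prompt_len=max_prompt_len,
                  inter_pp=inter_pp,
                  intra_pp=intra_pp,
                  decoder=True,
                  transpose_value_cache=transpose_value_cache)
```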


def optimize_llm(
model: torch.nn.Module,
max_output_len=1024,
@@ -168,31 +206,13 @@ def optimize_llm(
intra_pp = 2
if inter_pp is None:
inter_pp = 2 if group_size == 0 else 8

from ipex_llm.transformers.npu_models.llama_mp import gen_llama_fused_model_forward
from ipex_llm.transformers.npu_models.llama_mp import DecodeRunner, PrefillRunner
from transformers.models.llama.modeling_llama import LlamaModel

decode_runner = DecodeRunner(
model,
max_seq_len=max_output_len,
inter_pp=inter_pp,
intra_pp=intra_pp,
transpose_value_cache=transpose_value_cache,
)
prefill_runner = PrefillRunner(
model,
max_output_len=max_output_len,
max_prompt_len=max_prompt_len,
transpose_value_cache=transpose_value_cache,
)
llama_model_forward = gen_llama_fused_model_forward(
prefill_runner=prefill_runner, decode_runner=decode_runner
)
convert_forward(model, LlamaModel, llama_model_forward)
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from ipex_llm.transformers.npu_models.llama_mp import llama2_casullm_forward
convert_forward(model, LlamaForCausalLM, llama2_casullm_forward)
convert_llama(model,
max_output_len=max_output_len,
max_prompt_len=max_prompt_len,
inter_pp=inter_pp,
intra_pp=intra_pp,
decoder=True,
transpose_value_cache=transpose_value_cache)
elif model.config.model_type == "qwen2" and model.config.num_hidden_layers == 28:
# for qwen2-1.5B and qwen2-7B
if intra_pp is None:
@@ -63,85 +63,123 @@ def generate(
invalidInputError(input_length + new_tokens <= self.kv_len + 1,
"Input plus output tokens should not exceed max_output_len.")

# start generate_serve by Thread
thread = threading.Thread(target=generate_serve,
args=(self.kv_len, self.num_head,
self.head_dim, self.num_layers,
self.transpose_value_cache,
new_tokens - 1))
thread.start()

in_pipe_path = "\\\\.\\pipe\\llminputpipe"
out_pipe_path = "\\\\.\\pipe\\llmoutputpipe"

while True:
try:
input_pipe = open(in_pipe_path, "wb")
except:
print('Waiting for input pipe')
time.sleep(1)
else:
break

while True:
try:
output_pipe = open(out_pipe_path, "rb")
except:
print('Waiting for output pipe')
time.sleep(1)
else:
break
output_tokens = []

bdata = b''
for i in range(0, input_length):
d = int(numpy_input[i])
bdata = bdata + d.to_bytes(4, sys.byteorder)
with tempfile.TemporaryDirectory() as temp_dir:
# run prefill with PrefillRunner
output = self(input_ids=inputs,
attention_mask=torch.ones(1, inputs.shape[1]).int())
logits = output.logits
input_id = torch.argmax(logits[:, -1, :], dim=1)
input_id.to(torch.int32).numpy().tofile(os.path.join(temp_dir, "input_id.bin"))
position = np.int64(inputs.shape[1])
position.tofile(os.path.join(temp_dir, "position.bin"))
past_key_values = output.past_key_values
key_cache = past_key_values.key_cache
value_cache = past_key_values.value_cache
for layer in range(self.num_layers):
key_ = key_cache[layer]
val_ = value_cache[layer]
new_size = (
key_.size(0),
key_.size(1),
self.kv_len,
key_.size(3),
)
key = key_.as_strided(new_size, key_.stride(), storage_offset=0)
if not self.transpose_value_cache:
val = val_.as_strided(new_size, val_.stride(), storage_offset=0)
else:
new_size = (
val_.size(0),
val_.size(1),
val_.size(3),
self.kv_len,
)
val_cache = val_.transpose(-1, -2)
val = val_cache.as_strided(new_size, val_cache.stride(), storage_offset=0)
key.to(torch.float16).numpy().tofile(os.path.join(temp_dir, f"key_cache_{layer}.bin"))
val.to(torch.float16).numpy().tofile(os.path.join(temp_dir, f"value_cache_{layer}.bin"))

token = input_id.to(torch.int32).item()
output_tokens.append(torch.tensor([token]))
if streamer is not None:
streamer.put(torch.tensor([token]))

if "eos_token_id" not in new_generate_kwargs:
eos = 0xffffffff
else:
eos = new_generate_kwargs["eos_token_id"]
if "eos_token_id" not in new_generate_kwargs:
eos = 0xffffffff
else:
eos = new_generate_kwargs["eos_token_id"]

bdata = bdata + eos.to_bytes(4, sys.byteorder)
time_t1 = time.perf_counter()
idx += 1

time_start = time.perf_counter()
# start generate_serve by Thread
thread = threading.Thread(target=generate_serve,
args=(self.kv_len, self.num_head,
self.head_dim, self.num_layers,
self.transpose_value_cache,
new_tokens - 2))
thread.start()

input_pipe.write(bytearray(bdata))
input_pipe.flush()
in_pipe_path = "\\\\.\\pipe\\llminputpipe"
out_pipe_path = "\\\\.\\pipe\\llmoutputpipe"

buffersize = 4
output_tokens = []
while True:
data = output_pipe.read(buffersize)
if len(data) == 0:
break
token = int.from_bytes(data, sys.byteorder)
idx += 1
if time_t1 is None:
time_t1 = time.perf_counter()
output_tokens.append(torch.tensor([token]))
if streamer is not None:
streamer.put(torch.tensor([token]))
if token == eos:
break
while True:
try:
input_pipe = open(in_pipe_path, "wb")
except:
print('Waiting for input pipe')
time.sleep(1)
else:
break

output = torch.stack(output_tokens, dim=1)
output = torch.cat((inputs, output), dim=1)
if streamer is not None:
streamer.end()
while True:
try:
output_pipe = open(out_pipe_path, "rb")
except:
print('Waiting for output pipe')
time.sleep(1)
else:
break

time_start = time.perf_counter()

bdata = str.encode(str(temp_dir))
invalidInputError(len(bdata) <= 2000,
f"Input directory is too long ({len(bdata)}), which may cause read error.")
input_pipe.write(bdata)
input_pipe.flush()

buffersize = 4
while True:
data = output_pipe.read(buffersize)
if len(data) == 0:
break
token = int.from_bytes(data, sys.byteorder)
idx += 1
output_tokens.append(torch.tensor([token]))
if streamer is not None:
streamer.put(torch.tensor([token]))
if token == eos:
break

output = torch.stack(output_tokens, dim=1)
output = torch.cat((inputs, output), dim=1)
if streamer is not None:
streamer.end()

thread.join()
time_end = time.perf_counter()

if do_print:
print(f" Start the thread and connect the pipe time: {(time_start - time_start_all):.2f} s")
print(f" Start the thread and connect the pipe time: {(time_start - time_t1):.2f} s")
print(f" Number of input tokens: {input_length}")
print(f" Generated tokens: {idx}")
print(f" First token generation time: {(time_t1 - time_start):.2f} s")
print(f" Generation average latency: {(time_end - time_t1)*1000 /(idx - 1):.2f} ms, "
f"({(idx - 1)/(time_end - time_t1):.2f} token/s)")
print(f" Generation time: {(time_end - time_start):.2f} s\n")

print(f" First token generation time: {(time_t1 - time_start_all):.2f} s")
print(f" Generation average latency: {(time_end - time_start) * 1000 /(idx - 1):.2f} ms, "
f"({(idx - 1)/(time_end - time_start):.2f} token/s)")
print(f" Generation time: {(time_end - time_start_all - (time_start - time_t1)):.2f} s\n")
return output
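
In the rewritten `generate` above, prefill now runs in Python through the acc-lib `PrefillRunner` path (the `self(...)` call), the first token, its position, and the per-layer key/value caches are written to a temporary directory, and only that directory path (guarded to at most 2000 bytes) is sent over the input pipe; `generate_serve` is then started for `new_tokens - 2` decode steps. The least obvious step is the `as_strided` re-view that exposes the cache's pre-allocated `kv_len` slots, so here is a standalone sketch of it with a hypothetical helper name and toy shapes. It assumes, as with the fused NPU cache, that the storage already reserves `kv_len` positions along the sequence axis.

```python
# Hypothetical helper illustrating the KV-cache re-view used above; not the
# shipped implementation.
import torch

def pad_kv_to_full_len(key_, val_, kv_len, transpose_value_cache):
    # key_/val_: (batch, num_heads, cur_seq, head_dim) views whose underlying
    # storage already reserves kv_len positions along the sequence dimension.
    new_size = (key_.size(0), key_.size(1), kv_len, key_.size(3))
    key = key_.as_strided(new_size, key_.stride(), storage_offset=0)
    if not transpose_value_cache:
        val = val_.as_strided(new_size, val_.stride(), storage_offset=0)
    else:
        # Value cache is dumped transposed: (batch, num_heads, head_dim, kv_len).
        new_size = (val_.size(0), val_.size(1), val_.size(3), kv_len)
        val_t = val_.transpose(-1, -2)
        val = val_t.as_strided(new_size, val_t.stride(), storage_offset=0)
    return key, val

# Toy usage: storage holds kv_len slots, only cur_seq of them are filled so far.
kv_len, cur_seq, num_heads, head_dim = 16, 4, 2, 8
buf_k = torch.zeros(1, num_heads, kv_len, head_dim)
buf_v = torch.zeros(1, num_heads, kv_len, head_dim)
key_, val_ = buf_k[:, :, :cur_seq, :], buf_v[:, :, :cur_seq, :]

key, val = pad_kv_to_full_len(key_, val_, kv_len, transpose_value_cache=True)
assert key.shape == (1, num_heads, kv_len, head_dim)
assert val.shape == (1, num_heads, head_dim, kv_len)
```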


@@ -182,8 +220,15 @@ def update_names_of_IR_and_export_blob(model, model_name, dir):

def convert_llm(model: torch.nn.Module,
kv_len: int,
max_prompt_len: int,
transpose_value_cache: bool):
if model.config.model_type == "llama":
from ipex_llm.transformers.npu_models.convert_mp import convert_llama
convert_llama(model,
max_output_len=kv_len,
max_prompt_len=max_prompt_len,
decoder=False,
transpose_value_cache=transpose_value_cache)
from .llama import LowBitLlamaLMHead, LlamaEmbedding
with tempfile.TemporaryDirectory() as temp_dir:
# generate lm_head blob
Expand Down Expand Up @@ -231,13 +276,12 @@ def convert_llm(model: torch.nn.Module,
new_embedding = LlamaEmbedding(
vocab_size=model.config.vocab_size,
embedding_dim=model.config.hidden_size,
embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(),
padding_idx=model.config.pad_token_id,
dtype=np.float16,
)
first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
temp_dir)
bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)

# generate decoder layer blob
from ipex_llm.transformers.npu_models.llama_mp import LowBitLlamaMultiDecoderlayer
@@ -80,6 +80,7 @@ def __init__(
self,
vocab_size,
embedding_dim,
embedding_weight,
padding_idx,
dtype, # fp16
device: str = "NPU",
@@ -91,7 +92,7 @@ def __init__(
self.dtype = dtype

# define input
weight = self.parameter((vocab_size, embedding_dim))
weight = self.constant(embedding_weight)
input = self.parameter((1, 1), dtype=np.int32)

if padding_idx == -1:
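
Finally, `LlamaEmbedding` now receives the embedding weight as a constructor argument and folds it into the IR as a constant (`self.constant`) instead of declaring a runtime parameter, so `convert_llm` no longer writes `model_embedding_input_0.bin`. A small self-contained toy illustrating the difference (shapes and file handling are illustrative only, not the shipped code):

```python
# Toy illustration: the old flow serialized the embedding table for the runtime to
# load as blob input 0; the new flow hands the same fp16 array to
# LlamaEmbedding(embedding_weight=...) so it is compiled in as a constant.
import os
import tempfile
import numpy as np
import torch

weight = torch.randn(100, 16).to(torch.float16)   # stand-in embedding table

# Old flow (removed in this PR): dump the table and reload it at runtime.
with tempfile.TemporaryDirectory() as weight_dir:
    bin_file = os.path.join(weight_dir, "model_embedding_input_0.bin")
    weight.detach().numpy().tofile(bin_file)
    reloaded = np.fromfile(bin_file, dtype=np.float16).reshape(weight.shape)
    assert np.array_equal(reloaded, weight.numpy())

# New flow: pass the array straight into the IR builder as a constant, e.g.
# LlamaEmbedding(..., embedding_weight=weight.detach().numpy(), dtype=np.float16);
# no weight file is written.
embedding_weight = weight.detach().numpy()
```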