diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/README.md b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/README.md index b4203eb6f58..35b9165cc6d 100644 --- a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/README.md +++ b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/README.md @@ -22,7 +22,7 @@ pip install mpi4py fastapi uvicorn openai pip install gradio # for gradio web UI conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc -pip install transformers==4.31.0 # for llama2 models +pip install transformers==4.37.0 ``` ### 2. Run pipeline parallel serving on multiple GPUs diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/benchmark.py b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/benchmark.py index 5b32796c105..dab6146cbfa 100644 --- a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/benchmark.py +++ b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/benchmark.py @@ -30,7 +30,6 @@ def perform_request(session, url, payload, headers): start_time = time.perf_counter() with session.post(url, json=payload, headers=headers, stream=True) as response: response.raise_for_status() - first_token_time = None last_token_time = 0 first_token_inference_time = None @@ -38,21 +37,29 @@ def perform_request(session, url, payload, headers): next_token_time = [] i = 0 for line in response.iter_lines(): - token_time = time.perf_counter() - start_time if line: - data = line.decode("utf-8").strip() - i = i + 1 - try: - json_data = json.loads(data) - if json_data["message"] is not None: - if first_token_time is None: - first_token_time = token_time - else: - next_token_time.append(token_time - last_token_time) - last_token_time = token_time - except json.JSONDecodeError: - pass + data = line.decode('utf-8').strip() + if data.startswith('data: '): + data = data[len('data: '):] + i = i + 1 + try: + json_data = json.loads(data) + if 'choices' in json_data and len(json_data['choices']) > 0: + choice = json_data['choices'][0] + if 'finish_reason' in choice and (choice['finish_reason'] == 'length' or choice['finish_reason'] == 'stop'): + if 'first_token_time' in choice and isinstance(choice['first_token_time'], float): + first_token_inference_time = choice['first_token_time'] + if 'rest_token_time' in choice and isinstance(choice['rest_token_time'], float): + next_token_inference_time = choice['rest_token_time'] + else: + if first_token_time is None: + first_token_time = token_time + else: + next_token_time.append(token_time - last_token_time) + last_token_time = token_time + except json.JSONDecodeError: + pass end_time = time.perf_counter() return ( first_token_time, @@ -76,11 +83,11 @@ def extend_list_to_length(lst, target_length): def benchmark( llm_urls, prompt, + num_warmup_requests, num_requests, max_concurrent_requests, max_tokens, prompt_length, - is_warmup=False, ): headers = {"Content-Type": "application/json"} @@ -92,6 +99,8 @@ def benchmark( next_token_inference_times = [] cur_url_index = 0 + num_requests = num_requests + num_warmup_requests + with requests.Session() as session: with ThreadPoolExecutor(max_workers=max_concurrent_requests) as executor: llm_url = llm_urls[cur_url_index] @@ -101,8 +110,17 @@ def benchmark( cur_len = len(cur_llm_urls) payload = { + "model": "Meta-Llama-3-8B-Instruct", "prompt": prompt, - "n_predict": max_tokens, + "max_tokens": max_tokens, + "stream": True, + # for vllm openai api server + "ignore_eos": True, + "n": 1, + "best_of": 1, + "use_beam_search": False, + "temperature": 0.0, + "top_p": 1.0, } futures = [ 
executor.submit( @@ -115,14 +133,13 @@ def benchmark( for index in range(num_requests) ] - start_time = time.perf_counter() + phase = "Benchmarking" - if is_warmup: - phase = "Warm Up" - else: - phase = "Benchmarking" with tqdm(total=num_requests, desc=phase, unit="req", ncols=100) as pbar: + cur_index = 0 for future in concurrent.futures.as_completed(futures): + if cur_index == num_warmup_requests: + start_time = time.perf_counter() try: ( first_token_latency, @@ -131,21 +148,21 @@ def benchmark( first_token_inference_time, next_token_inference_time, ) = future.result() - first_token_latencies.append(first_token_latency) - next_token_latencies.append(next_token_latency) - total_responce_times.append(total_responce_time) - if first_token_inference_time: - first_token_inference_times.append( - first_token_inference_time - ) - if next_token_inference_time: - next_token_inference_times.append(next_token_inference_time) + cur_index = cur_index + 1 + if cur_index > num_warmup_requests: + first_token_latencies.append(first_token_latency) + next_token_latencies.append(next_token_latency) + total_responce_times.append(total_responce_time) + if first_token_inference_time: + first_token_inference_times.append( + first_token_inference_time + ) + if next_token_inference_time: + next_token_inference_times.append(next_token_inference_time) except Exception as e: print(f"Request failed: {e}") pbar.update(1) - if is_warmup: - return total_time = time.perf_counter() - start_time log_file = f"{max_concurrent_requests}.log" @@ -174,9 +191,6 @@ def benchmark( ) p90_first_token_latency = np.percentile(first_token_latencies, 90) p95_first_token_latency = np.percentile(first_token_latencies, 95) - # average_first_token_inference_latency = np.mean( - # first_token_inference_times - # ) print( f"Average first token latency: {average_first_token_latency * 1000} milliseconds.", file=file, @@ -189,10 +203,6 @@ def benchmark( f"P95 first token latency: {p95_first_token_latency * 1000} milliseconds.", file=file, ) - # print( - # f"Average first token inference latency: {average_first_token_inference_latency * 1000} milliseconds.", - # file=file, - # ) print(file=file) if next_token_latencies: @@ -201,9 +211,6 @@ def benchmark( ) p90_next_token_latency = np.percentile(next_token_latencies, 90) p95_next_token_latency = np.percentile(next_token_latencies, 95) - # average_next_token_inference_latency = np.mean( - # next_token_inference_times - # ) print( f"Average next token latency: {average_next_token_latency * 1000} milliseconds.", file=file, @@ -216,14 +223,10 @@ def benchmark( f"P95 next token latency: {p95_next_token_latency * 1000} milliseconds.", file=file, ) - # print( - # f"Average next token inference latency: {average_next_token_inference_latency * 1000} milliseconds.", - # file=file, - # ) print(file=file) -LLM_URLS = [f"http://localhost:{PORT}/generate_stream/" for PORT in [8000]] +LLM_URLS = [f"http://localhost:{PORT}/v1/completions" for PORT in [8000]] parser = argparse.ArgumentParser(description="Set prompt length.") parser.add_argument( @@ -254,17 +257,6 @@ def benchmark( for MAX_CONCURRENT_REQUESTS in args.max_concurrent_requests: NUM_WARMUP = 5 * MAX_CONCURRENT_REQUESTS - NUM_REQUESTS = 10 * MAX_CONCURRENT_REQUESTS - - # warm up - benchmark( - LLM_URLS, - PROMPT, - NUM_WARMUP, - MAX_CONCURRENT_REQUESTS, - MAX_TOKENS, - PROMPT_LENGTH, - is_warmup=True, - ) + NUM_REQUESTS = 30 * MAX_CONCURRENT_REQUESTS - benchmark(LLM_URLS, PROMPT, NUM_REQUESTS, MAX_CONCURRENT_REQUESTS, MAX_TOKENS, PROMPT_LENGTH) + 
benchmark(LLM_URLS, PROMPT, NUM_WARMUP, NUM_REQUESTS, MAX_CONCURRENT_REQUESTS, MAX_TOKENS, PROMPT_LENGTH) diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/llama_models.py b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/llama_models.py deleted file mode 100644 index 29bac7a70b7..00000000000 --- a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/llama_models.py +++ /dev/null @@ -1,328 +0,0 @@ -from transformers.modeling_utils import PreTrainedModel -from transformers.models.llama.modeling_llama import LlamaConfig, LlamaDecoderLayer, LlamaRMSNorm, LlamaPreTrainedModel -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast - -from torch import nn -import torch -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from typing import List, Optional, Tuple, Union, Iterator -from transformers.utils import logging -logger = logging.get_logger(__name__) -import numpy as np -import time -from transformers import AutoTokenizer, AutoConfig -import torch.distributed as dist -from pipeline_models import ( - _make_causal_mask, _expand_mask, DummyLayer, PPConfig, - PipelineBaseModel, -) - - -class LlamaModel(LlamaPreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] - - Args: - config: LlamaConfig - """ - - def __init__(self, config: LlamaConfig): - super().__init__(config) - self.config = config - - # pp modification - self.pp_config = PPConfig(pp_rank=dist.get_rank(), pp_world_size=dist.get_world_size()) - nr_slices = self.pp_config.pp_world_size - # self.config.num_hidden_layers = 8 - slice_size = (self.config.num_hidden_layers + nr_slices - - 1) // nr_slices - self.layer_start = slice_size * self.pp_config.pp_rank - self.layer_end = self.layer_start + min(slice_size, - self.config.num_hidden_layers - self.layer_start) - self.num_layers = self.layer_end - self.layer_start - layers = [] - for i in range(self.config.num_hidden_layers): - if i < self.layer_start or i >= self.layer_end: - layers.append(DummyLayer()) - else: - layers.append(LlamaDecoderLayer(config)) - self.layers = nn.ModuleList(layers) - - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - if self.pp_config.is_head: - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - if self.pp_config.is_tail: - self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - def forward( - self, - 
input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds for pp - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - assert self.pp_config.is_head, "input_ids is only supported on the head stage" - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - assert not self.pp_config.is_head, "inputs_embeds is only supported on the tail stage" - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device - ) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - - hidden_states = inputs_embeds - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx in range(self.num_layers): - decoder_layer = self.layers[self.layer_start + idx] - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if self.pp_config.is_tail: - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - 
next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class LlamaForCausalLM(LlamaPreTrainedModel): - - def __init__(self, config: LlamaConfig): - super().__init__(config=config) - self.config = config - self.pp_config = PPConfig(pp_rank=dist.get_rank(), pp_world_size=dist.get_world_size()) - self.model = LlamaModel(config) - self.pretraining_tp = config.pretraining_tp - self.vocab_size = config.vocab_size - if self.pp_config.is_tail: - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - if self.pp_config.is_tail: - hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - return outputs - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - if past_key_values: - input_ids = input_ids[:, -1:] - - 
position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_models.py b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_models.py index f8f5ba77ca8..790f5ab8fb3 100644 --- a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_models.py +++ b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_models.py @@ -1,15 +1,15 @@ -from torch import nn import torch import torch.distributed as dist from typing import List, Optional, Tuple, Union, Iterator import time -from transformers import AutoTokenizer, AutoConfig +from transformers.cache_utils import Cache from transformers.utils import logging -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast + import numpy as np import asyncio, uuid import threading +from pydantic import BaseModel logger = logging.get_logger(__name__) @@ -23,227 +23,15 @@ def __init__(self, pp_rank: int, pp_world_size: int) -> None: self.is_head = self.pp_rank == 0 self.is_tail = self.pp_rank == self.pp_world_size - 1 -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
- """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - -def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - -class DummyLayer(nn.Module): - pass - - -class PipelineBaseModel(nn.Module): - def __init__(self, config): - self.pp_config = PPConfig(pp_rank=dist.get_rank(), pp_world_size=dist.get_world_size()) - nr_slices = self.pp_config.pp_world_size - # self.config.num_hidden_layers = 8 - slice_size = (self.config.num_hidden_layers + nr_slices - - 1) // nr_slices - self.layer_start = slice_size * self.pp_config.pp_rank - self.layer_end = self.layer_start + min(slice_size, - self.config.num_hidden_layers - self.layer_start) - self.num_layers = self.layer_end - self.layer_start - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds 
- if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - assert self.pp_config.is_head, "input_ids is only supported on the head stage" - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - assert not self.pp_config.is_head, "inputs_embeds is only supported on the tail stage" - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device - ) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - - hidden_states = inputs_embeds - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx in range(self.num_layers): - decoder_layer = self.layers[self.layer_start + idx] - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if self.pp_config.is_tail: - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -def load_model(checkpoint): - from llama_models import LlamaForCausalLM - if 'llama' in checkpoint.lower(): - model = LlamaForCausalLM.from_pretrained(checkpoint, low_cpu_mem_usage=True, torch_dtype=torch.float16) - return model - -from pydantic import BaseModel class BatchTask(BaseModel): batch_id: str request_ids: List[str] max_tokens: int batch_size: int input_len: int - # plain_texts: List[str] prompt_lengths: List[int] stopped: bool - # input_ids: torch.Tensor - # attention_mask: torch.Tensor def make_attention_mask(prompt_lengths): @@ -256,19 
+44,14 @@ def make_attention_mask(prompt_lengths): class ModelRunner: def __init__(self, checkpoint, rank, world_size, low_bit, max_num_seqs): - - import sys + self.pp_config = PPConfig(rank, world_size) start = time.perf_counter() - model = load_model(checkpoint) + model = self.load_model(checkpoint, rank, world_size, low_bit) end = time.perf_counter() logger.info(f"Time to load weights: {end - start:.2f}s") - from ipex_llm import optimize_model - model = optimize_model(model, low_bit=low_bit) - - model = model.to(torch.float16).to(f'xpu:{rank}') self.model = model self.rank = rank self.world_size = world_size @@ -295,44 +78,63 @@ def __init__(self, checkpoint, rank, world_size, low_bit, max_num_seqs): self.is_finish = {} self.model_name = checkpoint - - # def generate(self, input_ids=None, max_tokens=5, attention_mask=None): - # times = [] - # with torch.no_grad(): - # _input_ids = None - # _past_key_values = None - # bs = input_ids.shape[0] - # output_ids = input_ids.clone() - # for i in range(max_tokens): - # start = time.perf_counter() - # if _input_ids is None: - # _input_ids = input_ids - # if self.rank == 0: - # outputs = self.model(input_ids=_input_ids, attention_mask=attention_mask, past_key_values=_past_key_values, use_cache=True) - # else: - # inputs_embeds = torch.empty(_input_ids.shape + (self.hidden_size,) , device=f'xpu:{self.rank}', dtype=torch.float32) - # dist.recv(inputs_embeds, src=self.pre_rank) - # outputs = self.model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, past_key_values=_past_key_values, use_cache=True) - - # if self.rank == self.world_size - 1: - # logits = outputs.logits - # next_ids = torch.argmax(logits[:, -1:, :], dim=-1) - # assert next_ids.shape == (bs, 1) - # dist.broadcast(next_ids, src=self.rank) - # else: - # dist.send(outputs.last_hidden_state, dst=self.next_rank) - # next_ids = torch.empty((bs, 1), device=f'xpu:{self.rank}', dtype=torch.int64) - # dist.broadcast(next_ids, src=self.world_size - 1) - - # _input_ids = next_ids - # output_ids = torch.cat([output_ids, next_ids], dim=-1) - # _past_key_values = outputs.past_key_values - # end = time.perf_counter() - # times.append(end - start) - - # if self.rank == 0: - # logger.info(f"first token latency: {times[0]}, rest token avg latecy: {np.mean(times[1:])}") - # return output_ids + self.layer_start = 0 + + + def load_model(self, model_path, my_rank, my_size, low_bit='sym_int4'): + device = f"xpu:{my_rank}" + from ipex_llm.transformers import AutoModelForCausalLM + model = AutoModelForCausalLM.from_pretrained(model_path, + load_in_low_bit=low_bit, + torch_dtype=torch.float16, + optimize_model=True, + trust_remote_code=True, + use_cache=True, + pipeline_parallel_stages=my_size).eval() + # print(model) + + # config_class = type(model.config).__name__ + # if config_class == 'ChatGLMConfig': + # model.config.num_hidden_layers = model.config.num_layers + # nr_slices = my_size + # slice_size = (model.config.num_layers + nr_slices - 1) // nr_slices + # layer_start = slice_size * my_rank + # layer_end = layer_start + min(slice_size, model.config.num_layers - layer_start) + + # for i in range(model.config.num_layers): + # if i < layer_start or i >= layer_end: + # model.transformer.encoder.layers[i] = Dummy_DecoderLayer() + # else: + # pass + # # align layer_idx and len(past_key_values), otherwise abnormal output + # # model._modules['encoder'].layers[i].self_attention.layer_idx = i - layer_start + # # model.transformer.encoder.layers[i].self_attention.layer_idx = i - layer_start + + # if 
my_rank != 0: + # model.transformer.embedding = DummyLayer() + # if my_rank != my_size - 1: + # model.transformer.output_layer = DummyLayer() + + # else: + # nr_slices = my_size + # slice_size = (model.config.num_hidden_layers + nr_slices - 1) // nr_slices + # layer_start = slice_size * my_rank + # layer_end = layer_start + min(slice_size, model.config.num_hidden_layers - layer_start) + + # for i in range(model.config.num_hidden_layers): + # if i < layer_start or i >= layer_end: + # model._modules['model'].layers[i] = Dummy_DecoderLayer() + # else: + # # align layer_idx and len(past_key_values), otherwise abnormal output + # model._modules['model'].layers[i].self_attn.layer_idx = i - layer_start + # if my_rank != 0: + # model._modules['model'].embed_tokens = DummyLayer() + # if my_rank != my_size - 1: + # model._modules['model'].norm = DummyLayer() + # model._modules['lm_head'] = DummyLayer() + + # model = model.to(f'xpu:{my_rank}') + return model def model_step(self, input, cur_batch): @@ -341,7 +143,6 @@ def model_step(self, input, cur_batch): cur_id = cur_batch.batch_id _past_key_values = self.past_key_values_dict.get(cur_id, None) - # attention_mask = self.attention_mask_dict[cur_id] attention_mask = make_attention_mask(cur_batch.prompt_lengths) if self.rank == 0: @@ -350,18 +151,33 @@ def model_step(self, input, cur_batch): else: input_ids = None inputs_embeds = input + + # logger.info(f"{self.rank}, {_past_key_values}") output = self.model( input_ids=input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, past_key_values=_past_key_values, - use_cache=True + use_cache=True, + output_hidden_states=True, ) - self.past_key_values_dict[cur_id] = output.past_key_values + use_legacy_cache = not isinstance(output.past_key_values, Cache) + if use_legacy_cache and self.rank > 0: + if output.past_key_values[0] is None: + _past_key_values = list(output.past_key_values) + slice_size = (self.model.config.num_hidden_layers + self.world_size - 1) // self.world_size + layer_start = slice_size * self.rank + + _past_key_values[0] = [torch.empty_like(output.past_key_values[layer_start][0])] + _past_key_values = tuple(_past_key_values) + else: + _past_key_values = output.past_key_values + else: + _past_key_values = output.past_key_values + self.past_key_values_dict[cur_id] = _past_key_values if not self.pp_config.is_tail: - return output.last_hidden_state + return output.hidden_states[-1] else: - # logger.info(f"logits: {output.logits.shape}") return output.logits @@ -376,7 +192,6 @@ async def add_request(self, tokenizer): break tmp_result = await self.waiting_requests.get() - # logger.info(tmp_result) request_id, prompt_request = tmp_result request_ids.append(request_id) prompt_requests.append(prompt_request) @@ -393,14 +208,10 @@ async def add_request(self, tokenizer): input_len=input_ids.size(1), prompt_lengths=[sum(attention_mask[i,:]) for i in range(input_ids.size(0))], stopped=False, - # plain_texts=plain_texts, - # input_ids=input_ids, - # attention_mask=attention_mask, ) self.input_ids_dict[new_batch.batch_id] = input_ids self.token_times[new_batch.batch_id] = [time.perf_counter()] - # self.attention_mask_dict[new_batch.batch_id] = attention_mask return new_batch @@ -409,7 +220,6 @@ def clear_batch(self, cur_id): self.input_ids_dict.pop(cur_id, None) self.tokens.pop(cur_id, None) self.token_times.pop(cur_id, None) - # self.attention_mask_dict.pop(cur_id, None) self.past_key_values_dict.pop(cur_id, None) # torch.xpu.empty_cache() @@ -448,9 +258,7 @@ async def process_step(self, 
tokenizer, result_dict): next_ids = next_ids.unsqueeze(0) self.tokens[cur_id].append(next_ids) self.token_times[cur_id].append(time.perf_counter()) - # self.input_ids_dict[cur_id] += next_ids cur_input = next_ids - # cur_batch.input_len += 1 cur_batch.input_len = 1 cur_batch.prompt_lengths = [x + 1 for x in cur_batch.prompt_lengths] @@ -462,9 +270,10 @@ async def process_step(self, tokenizer, result_dict): if self.streamer.get(request_id, None) is None: self.streamer[request_id] = asyncio.Queue() - if next_ids[index].int() == tokenizer.eos_token_id: - remain = 0 - self.is_finish[request_id] = True + # Currently ignore eos for benchmark + # if next_ids[index].int() == tokenizer.eos_token_id: + # remain = 0 + # self.is_finish[request_id] = True if self.token_cache.get(request_id, None) is None: self.token_cache[request_id] = [] @@ -533,12 +342,6 @@ async def process_step(self, tokenizer, result_dict): cur_input = torch.empty((cur_batch.batch_size, cur_len, self.hidden_size,), device=f'xpu:{self.rank}', dtype=self.dtype) # logger.info(f"rank: {self.rank}, recv: {cur_input.shape}") dist.recv(cur_input, src=self.pre_rank) - - # if self.attention_mask_dict.get(cur_batch.batch_id, None) is None: - # self.attention_mask_dict[cur_batch.batch_id] = make_attention_mask(cur_batch.prompt_lengths) - - # if self.rank == 0: - # logger.info(f"rank: {self.rank}, {batch_list}") output = self.model_step(cur_input, cur_batch) if output is not None and self.rank == self.world_size - 1: @@ -576,4 +379,4 @@ def _is_chinese_char(cp): ): # return True - return False \ No newline at end of file + return False diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_serving.py b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_serving.py index aeaf0d1741a..bc0fb5c8d6f 100644 --- a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_serving.py +++ b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_serving.py @@ -3,19 +3,16 @@ import torch.distributed as dist import os -import ipex_llm from ipex_llm.utils.common import invalidInputError +from ipex_llm.transformers import init_pipeline_parallel import oneccl_bindings_for_pytorch import json from transformers.utils import logging logger = logging.get_logger(__name__) -os.environ['MASTER_ADDR'] = '127.0.0.1' -os.environ['MASTER_PORT'] = '29501' +init_pipeline_parallel() -backend = 'ccl' -dist.init_process_group(backend) my_rank = dist.get_rank() my_size = dist.get_world_size() device = f"xpu:{my_rank}" @@ -146,7 +143,7 @@ async def completion_stream_generator(local_model, delta_text_queue, request_id) if remain == 0: choice_data = CompletionResponseStreamChoice( index=index, - text=None, + text="", logprobs=None, finish_reason="length") chunk = CompletionStreamResponse( @@ -171,7 +168,6 @@ async def generator(local_model, delta_text_queue, request_id): break else: await asyncio.sleep(0) - # streamer_dict.pop(request_id, None) local_model.streamer.pop(request_id, None) @@ -282,29 +278,6 @@ async def create_completion(request: CompletionRequest): return result -def generate_text(prompt: List[str], n_predict = 32): - while prompt[-1] == "": - prompt = prompt[:-1] - if isinstance(n_predict, list): - n_predict = max(n_predict) - - inputs = tokenizer(prompt, return_tensors="pt", padding=True) - input_ids = inputs.input_ids.to(f'xpu:{local_rank}') - print(inputs) - attention_mask = inputs.attention_mask.to(f'xpu:{local_rank}') - output = local_model.generate(input_ids, - max_tokens=n_predict, - # attention_mask=attention_mask, - # 
max_new_tokens=n_predict, - # min_new_tokens=n_predict, - # do_sample=False, - # use_cache=True - ) - torch.xpu.synchronize() - - return output - - async def process_requests(local_model, result_dict): while True: await asyncio.sleep(0) diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/run.sh b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/run.sh index 3c6243d613e..02995299cb6 100644 --- a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/run.sh +++ b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/run.sh @@ -14,6 +14,6 @@ export TORCH_LLM_ALLREDUCE=0 export MODEL_PATH=YOUR_MODEL_PATH export NUM_GPUS=2 -export BIGDL_QUANTIZE_KV_CACHE=1 +export IPEX_LLM_QUANTIZE_KV_CACHE=1 CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS pipeline_serving.py --repo-id-or-model-path $MODEL_PATH --low-bit fp8 --max-num-seqs 4 diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py index d346d8aa578..4e070e411d8 100644 --- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py +++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py @@ -64,7 +64,9 @@ def __init__(self, *args): self.input_layernorm = DummyLayer() self.mlp = Dummy_MLPLayer() - def forward(self, hidden_states, past_key_value=None, use_cache=False, **kwargs): + def forward(self, hidden_states, *args, **kwargs): + past_key_value = kwargs.get('past_key_value', None) + use_cache = kwargs.get('use_cache', False) outputs = (hidden_states,) if use_cache: outputs += (past_key_value,)
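
Notes on using the updated endpoint: with this change the benchmark targets the OpenAI-compatible `/v1/completions` route and parses `data: `-prefixed SSE chunks. Below is a minimal client sketch that mirrors that request/response handling. The host/port, model name, and payload fields are taken from the diff (`benchmark.py` and `run.sh`); the prompt string and the keep-alive handling are illustrative assumptions, not part of the patch.

```python
# Minimal streaming client sketch for the /v1/completions endpoint served by
# pipeline_serving.py. Assumes the server from run.sh is listening on
# localhost:8000; payload fields mirror the ones used in benchmark.py.
import json
import requests

url = "http://localhost:8000/v1/completions"
payload = {
    "model": "Meta-Llama-3-8B-Instruct",   # as used in benchmark.py
    "prompt": "What is AI?",                # placeholder prompt
    "max_tokens": 32,
    "stream": True,
}

with requests.post(url, json=payload, stream=True) as response:
    response.raise_for_status()
    for line in response.iter_lines():
        if not line:
            continue
        data = line.decode("utf-8").strip()
        if data.startswith("data: "):
            data = data[len("data: "):]
        try:
            chunk = json.loads(data)
        except json.JSONDecodeError:
            continue  # skip keep-alive or sentinel lines that are not JSON
        choices = chunk.get("choices") or [{}]
        choice = choices[0]
        print(choice.get("text", ""), end="", flush=True)
        if choice.get("finish_reason") in ("length", "stop"):
            break
```

The benchmark additionally sets `ignore_eos`, `n`, `best_of`, `use_beam_search`, `temperature`, and `top_p` (labelled "for vllm openai api server" in the diff) so that every request generates exactly `max_tokens`; a plain client can omit those fields.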
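
On the serving side, the patch drops the local `llama_models.py` implementation and the manual `MASTER_ADDR`/`dist.init_process_group` setup in favor of `init_pipeline_parallel()` and `AutoModelForCausalLM.from_pretrained(..., pipeline_parallel_stages=...)`. The sketch below shows how a single rank loads its pipeline stage, following the `load_model()` method added to `ModelRunner` in `pipeline_models.py`; `YOUR_MODEL_PATH` is a placeholder and the `fp8` low-bit setting mirrors the `--low-bit fp8` argument in `run.sh`.

```python
# Per-rank model loading with pipeline parallelism, following the
# load_model() added to ModelRunner in pipeline_models.py.
import torch
import torch.distributed as dist
from ipex_llm.transformers import AutoModelForCausalLM, init_pipeline_parallel

init_pipeline_parallel()            # replaces the manual MASTER_ADDR / ccl setup
rank = dist.get_rank()
world_size = dist.get_world_size()

model = AutoModelForCausalLM.from_pretrained(
    "YOUR_MODEL_PATH",                     # placeholder, see run.sh MODEL_PATH
    load_in_low_bit="fp8",                 # value passed via --low-bit
    torch_dtype=torch.float16,
    optimize_model=True,
    trust_remote_code=True,
    use_cache=True,
    pipeline_parallel_stages=world_size,   # one pipeline stage per rank
).eval()
# Each rank keeps only its own pipeline stage of the decoder layers.
```

As in `run.sh`, this is expected to be launched with `torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS`, one process per GPU, with `CCL_ZE_IPC_EXCHANGE=sockets` set.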