diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
index 91d23797c26..14c4b8a1c74 100644
--- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
+++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
@@ -19,13 +19,17 @@
 import torch
 from torch import nn
+from torch.nn import CrossEntropyLoss
+import torch.nn.functional as F
 import torch.distributed as dist
 import os
 import time
 import numpy as np
-from typing import Callable, List, Optional
+from typing import Callable, List, Optional, Union, Tuple
 from types import SimpleNamespace
+import transformers
 from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.ggml.quantize import ggml_tensor_qtype
 import logging
@@ -107,6 +111,34 @@ def init_pipeline_parallel():
     dist.init_process_group('ccl')
 
 
+def low_mem_convert(model):
+    from ipex_llm.transformers.convert import convert_forward
+    import importlib
+    if 'llama' in model.config.model_type:
+        convert_forward(
+            model,
+            transformers.models.llama.modeling_llama.LlamaForCausalLM,
+            llama_causallm_forward_4_37_lowmem)
+    elif model.config.model_type == "chatglm" and not hasattr(model.config, "vision_config"):
+        if model.config.num_layers == 40:
+            # for glm4-9b
+            modeling_module_name = model.__class__.__module__
+            module = importlib.import_module(modeling_module_name)
+            convert_forward(
+                model,
+                module.ChatGLMForConditionalGeneration,
+                glm4_conditional_generation_forward_lowmem)
+        else:
+            # for chatglm3-6b
+            modeling_module_name = model.__class__.__module__
+            module = importlib.import_module(modeling_module_name)
+            convert_forward(
+                model,
+                module.ChatGLMForConditionalGeneration,
+                chatglm3_conditional_generation_forward_lowmem)
+    return model
+
+
 def _check_quantize_kv_cache(model, idx, batch_size):
     # align use_quantize_kv_cache setting for different GPU in pipeline parallel
     pp_quantize_kv_cache = (os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) == "1") or \
@@ -186,6 +218,11 @@ def pipeline_parallel(model, pipeline_parallel_stages):
         model._modules['model'].norm = DummyLayer()
         model._modules['lm_head'] = DummyLayer()
 
+    _enable_lowmem = os.getenv('IPEX_LLM_LOW_MEM')
+    _enable_lowmem = (_enable_lowmem is not None) and (_enable_lowmem.lower() == "1")
+    if _enable_lowmem:
+        model = low_mem_convert(model)
+
     model.pipeline_parallel_stages = pipeline_parallel_stages
     model.layer_start = layer_start
     model.layer_end = layer_end
@@ -867,3 +904,208 @@ def _is_chinese_char(cp):
             return True
 
     return False
+
+
+def llama_causallm_forward_4_37_lowmem(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions  # noqa
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states  # noqa
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
+
+    hidden_states = outputs[0]
+
+    # ipex-llm change starts
+    if self.config.pretraining_tp > 1:
+        lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)  # noqa
+        logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]  # noqa
+        logits = torch.cat(logits, dim=-1)
+    else:
+        torch.xpu.empty_cache()
+        logits = self.lm_head(hidden_states)
+        torch.xpu.empty_cache()
+    # logits = logits.float()
+    # ipex-llm change ends
+
+    loss = None
+    if labels is not None:
+        # Shift so that tokens < n predict n
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        loss_fct = CrossEntropyLoss()
+        shift_logits = shift_logits.view(-1, self.config.vocab_size)
+        shift_labels = shift_labels.view(-1)
+        # Enable model parallelism
+        shift_labels = shift_labels.to(shift_logits.device)
+        loss = loss_fct(shift_logits, shift_labels)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
+
+
+def chatglm3_conditional_generation_forward_lowmem(
+    self,
+    input_ids: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.Tensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.Tensor] = None,
+    labels: Optional[torch.Tensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    return_last_logit: Optional[bool] = False,
+):
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    transformer_outputs = self.transformer(
+        input_ids=input_ids,
+        position_ids=position_ids,
+        attention_mask=attention_mask,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
+
+    hidden_states = transformer_outputs[0]
+    if return_last_logit:
+        hidden_states = hidden_states[-1:]
+
+    # ipex-llm change starts
+    torch.xpu.empty_cache()
+    lm_logits = self.transformer.output_layer(hidden_states)
+    torch.xpu.empty_cache()
+    lm_logits = lm_logits.transpose(0, 1).contiguous()
+
+    loss = None
+    if labels is not None:
+        # lm_logits = lm_logits.to(torch.float32)
+
+        # Shift so that tokens < n predict n
+        shift_logits = lm_logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        loss_fct = CrossEntropyLoss(ignore_index=-100)
+        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+        lm_logits = lm_logits.to(hidden_states.dtype)
+        loss = loss.to(hidden_states.dtype)
+    # ipex-llm change ends
+
+    if not return_dict:
+        output = (lm_logits,) + transformer_outputs[1:]
+        return ((loss,) + output) if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=lm_logits,
+        past_key_values=transformer_outputs.past_key_values,
+        hidden_states=transformer_outputs.hidden_states,
+        attentions=transformer_outputs.attentions,
+    )
+
+
+def glm4_conditional_generation_forward_lowmem(
+    self,
+    input_ids: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.Tensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.Tensor] = None,
+    labels: Optional[torch.Tensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    return_last_logit: Optional[bool] = False,
+):
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    transformer_outputs = self.transformer(
+        input_ids=input_ids,
+        position_ids=position_ids,
+        attention_mask=attention_mask,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
+
+    hidden_states = transformer_outputs[0]
+    if return_last_logit:
+        hidden_states = hidden_states[:, -1:]
+    # ipex-llm change starts
+    torch.xpu.empty_cache()
+    lm_logits = self.transformer.output_layer(hidden_states)
+    torch.xpu.empty_cache()
+
+    loss = None
+    if labels is not None:
+        # lm_logits = lm_logits.to(torch.float32)
+
+        # Shift so that tokens < n predict n
+        shift_logits = lm_logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        loss_fct = CrossEntropyLoss(ignore_index=-100)
+        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+        lm_logits = lm_logits.to(hidden_states.dtype)
+        loss = loss.to(hidden_states.dtype)
+    # ipex-llm change ends
+
+    if not return_dict:
+        output = (lm_logits,) + transformer_outputs[1:]
+        return ((loss,) + output) if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=lm_logits,
+        past_key_values=transformer_outputs.past_key_values,
+        hidden_states=transformer_outputs.hidden_states,
+        attentions=transformer_outputs.attentions,
+    )
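
Reviewer note on enabling the change: the new path is gated entirely by the `IPEX_LLM_LOW_MEM` environment variable, which `pipeline_parallel()` reads once while partitioning the model, so it must be set before the model is loaded and converted. A minimal sketch of a launcher-side toggle (the surrounding pipeline-parallel launch script is assumed, not shown):

```python
import os

# Must be set before pipeline_parallel()/model loading runs; any value other
# than "1" (case-insensitive) keeps the default forward implementations.
os.environ.setdefault("IPEX_LLM_LOW_MEM", "1")

# Same check the patch performs inside pipeline_parallel():
_enable_lowmem = os.getenv("IPEX_LLM_LOW_MEM")
_enable_lowmem = _enable_lowmem is not None and _enable_lowmem.lower() == "1"
print(f"low-mem lm_head path enabled: {_enable_lowmem}")
```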
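For context, `low_mem_convert()` leans on `ipex_llm.transformers.convert.convert_forward` to rebind the forward of every module whose class matches the target. The sketch below is a hypothetical stand-in for what such a helper is assumed to do; it is not the library's actual implementation.

```python
import types
import torch.nn as nn

def convert_forward_sketch(model: nn.Module, target_cls: type, new_forward) -> None:
    # Bind new_forward as the forward method of every sub-module (including the
    # root model) whose class matches target_cls. Hypothetical stand-in for
    # ipex_llm.transformers.convert.convert_forward.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)
```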
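The reason the low-mem forwards keep the logits in the model dtype (the `logits.float()` upcast stays commented out) and call `torch.xpu.empty_cache()` around the `lm_head`/`output_layer` projection is the size of the `[batch, seq_len, vocab_size]` logits tensor materialized on the last pipeline stage. A rough back-of-the-envelope helper, with illustrative numbers that are not taken from the PR:

```python
import torch

def logits_buffer_bytes(batch: int, seq_len: int, vocab_size: int,
                        dtype: torch.dtype = torch.float16) -> int:
    # Approximate size of the [batch, seq_len, vocab_size] logits tensor that
    # the lm_head / output_layer projection produces.
    elem = torch.tensor([], dtype=dtype).element_size()
    return batch * seq_len * vocab_size * elem

# A single 4096-token prompt against a ~150k-entry vocabulary is already on the
# order of a GiB in fp16; upcasting to float32 would double it, which is what
# the low-mem forwards avoid.
print(logits_buffer_bytes(1, 4096, 150000) / 2**30)
```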
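One more note on the `return_last_logit` handling: the chatglm3 forward slices `hidden_states[-1:]` and later transposes the logits because that transformer is assumed to return sequence-first activations, while the glm4 forward slices `hidden_states[:, -1:]` and needs no transpose because its activations are batch-first. A toy shape check under those assumptions (dummy tensors, not model output):

```python
import torch

batch, seq, hidden = 2, 8, 16

chatglm3_hidden = torch.randn(seq, batch, hidden)   # assumed [seq, batch, hidden]
glm4_hidden = torch.randn(batch, seq, hidden)       # assumed [batch, seq, hidden]

last_chatglm3 = chatglm3_hidden[-1:]      # -> [1, batch, hidden]; transpose(0, 1) makes it batch-first
last_glm4 = glm4_hidden[:, -1:]           # -> [batch, 1, hidden]; already batch-first

print(last_chatglm3.transpose(0, 1).shape, last_glm4.shape)
# torch.Size([2, 1, 16]) torch.Size([2, 1, 16])
```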