Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LLM: Add XPU Memory Optimizations for Pipeline Parallel #11567

Merged
merged 2 commits into from
Jul 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
244 changes: 243 additions & 1 deletion python/llm/src/ipex_llm/transformers/pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,17 @@

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
import torch.distributed as dist
import os
import time
import numpy as np
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Union, Tuple
from types import SimpleNamespace
import transformers
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ipex_llm.utils.common import invalidInputError
from ipex_llm.ggml.quantize import ggml_tensor_qtype
import logging
Expand Down Expand Up @@ -107,6 +111,34 @@ def init_pipeline_parallel():
dist.init_process_group('ccl')


def low_mem_convert(model):
from ipex_llm.transformers.convert import convert_forward
import importlib
if 'llama' in model.config.model_type:
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaForCausalLM,
llama_causallm_forward_4_37_lowmem)
elif model.config.model_type == "chatglm" and not hasattr(model.config, "vision_config"):
if model.config.num_layers == 40:
# for glm4-9b
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
convert_forward(
model,
module.ChatGLMForConditionalGeneration,
glm4_conditional_generation_forward_lowmem)
else:
# for chatglm3-6b
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
convert_forward(
model,
module.ChatGLMForConditionalGeneration,
chatglm3_conditional_generation_forward_lowmem)
return model


def _check_quantize_kv_cache(model, idx, batch_size):
# align use_quantize_kv_cache setting for different GPU in pipeline parallel
pp_quantize_kv_cache = (os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) == "1") or \
Expand Down Expand Up @@ -186,6 +218,11 @@ def pipeline_parallel(model, pipeline_parallel_stages):
model._modules['model'].norm = DummyLayer()
model._modules['lm_head'] = DummyLayer()

_enable_lowmem = os.getenv('IPEX_LLM_LOW_MEM')
_enable_lowmem = (_enable_lowmem is not None) and (_enable_lowmem.lower() == "1")
if _enable_lowmem:
model = low_mem_convert(model)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure this memory optimization could also be added in single GPU inference case? Other LGTM.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could, but for single GPU current existing optimizations may be enough for the demands, so we don't override the CausalLM forward previously.


model.pipeline_parallel_stages = pipeline_parallel_stages
model.layer_start = layer_start
model.layer_end = layer_end
Expand Down Expand Up @@ -867,3 +904,208 @@ def _is_chinese_char(cp):
return True

return False


def llama_causallm_forward_4_37_lowmem(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions # noqa
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states # noqa
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

hidden_states = outputs[0]

# ipex-llm change starts

if self.config.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) # noqa
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] # noqa
logits = torch.cat(logits, dim=-1)
else:
torch.xpu.empty_cache()
logits = self.lm_head(hidden_states)
torch.xpu.empty_cache()
# logits = logits.float()

# ipex-llm change ends

loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)

if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output

return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)


def chatglm3_conditional_generation_forward_lowmem(
self,
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
return_last_logit: Optional[bool] = False,
):
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

transformer_outputs = self.transformer(
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

hidden_states = transformer_outputs[0]
if return_last_logit:
hidden_states = hidden_states[-1:]

# ipex-llm change starts
torch.xpu.empty_cache()
lm_logits = self.transformer.output_layer(hidden_states)
torch.xpu.empty_cache()
lm_logits = lm_logits.transpose(0, 1).contiguous()

loss = None
if labels is not None:
# lm_logits = lm_logits.to(torch.float32)

# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

lm_logits = lm_logits.to(hidden_states.dtype)
loss = loss.to(hidden_states.dtype)
# ipex-llm change ends

if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output

return CausalLMOutputWithPast(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)


def glm4_conditional_generation_forward_lowmem(
self,
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
return_last_logit: Optional[bool] = False,
):
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

transformer_outputs = self.transformer(
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

hidden_states = transformer_outputs[0]
if return_last_logit:
hidden_states = hidden_states[:, -1:]
# ipex-llm change starts
torch.xpu.empty_cache()
lm_logits = self.transformer.output_layer(hidden_states)
torch.xpu.empty_cache()

loss = None
if labels is not None:
# lm_logits = lm_logits.to(torch.float32)

# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

lm_logits = lm_logits.to(hidden_states.dtype)
loss = loss.to(hidden_states.dtype)
# ipex-llm change ends

if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output

return CausalLMOutputWithPast(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
Loading