support llava 1.5 #1217

Merged · 4 commits · Jun 25, 2024
Changes from 2 commits
4 changes: 4 additions & 0 deletions swift/llm/deploy.py
@@ -510,6 +510,8 @@ def _generate_stream():
 async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request) -> ChatCompletionResponse:
     global _args
     assert _args is not None
+    if request.stop is None:
+        request.stop = []
     if _args.infer_backend == 'vllm':
         return await inference_vllm_async(request, raw_request)
     else:
@@ -520,6 +522,8 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
 async def create_completion(request: CompletionRequest, raw_request: Request) -> CompletionResponse:
     global _args
     assert _args is not None
+    if request.stop is None:
+        request.stop = []
     if _args.infer_backend == 'vllm':
         return await inference_vllm_async(request, raw_request)
     else:
8 changes: 5 additions & 3 deletions swift/llm/export.py
@@ -94,25 +94,27 @@ def llm_export(args: ExportArguments) -> None:
     logger.info(f'args: {args}')
     seed_everything(args.seed)
     if args.to_peft_format:
-        assert args.sft_type == 'lora'
+        assert args.sft_type == 'lora', f'args.sft_type: {args.sft_type}'
         args.ckpt_dir = swift_to_peft_format(args.ckpt_dir)
     if args.merge_lora:
         merge_lora(args, device_map=args.merge_device_map)
     if args.quant_bits > 0:
         _args = args
-        assert args.quantization_bit == 0
+        assert args.quantization_bit == 0, f'args.quantization_bit: {args.quantization_bit}'
         assert args.sft_type == 'full', 'you need to merge lora'
         if args.quant_method == 'awq':
             from awq import AutoAWQForCausalLM
             model, template = prepare_model_template(
                 args, device_map=args.quant_device_map, verbose=False, automodel_class=AutoAWQForCausalLM)
             awq_model_quantize(model, template.tokenizer)
             model.save_quantized(args.quant_output_dir)
-        else:  # gptq
+        elif args.quant_method == 'gptq':
             model, template = prepare_model_template(args, device_map=args.quant_device_map, verbose=False)
             gptq_quantizer = gptq_model_quantize(model, template.tokenizer)
             model.config.quantization_config.pop('dataset', None)
             gptq_quantizer.save(model, args.quant_output_dir)
+        else:
+            raise ValueError(f'args.quant_method: {args.quant_method}')

     logger.info(get_model_info(model))
     show_layers(model)
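For context, the dispatch in `llm_export` previously treated anything that was not `awq` as GPTQ; it now fails fast on unexpected values. A minimal standalone sketch of the new behavior (the `SimpleNamespace` stand-in is illustrative, not the real `ExportArguments`):

```python
from types import SimpleNamespace


def dispatch_quant(args):
    # Mirrors the tightened awq/gptq dispatch in llm_export.
    if args.quant_method == 'awq':
        return 'awq path'
    elif args.quant_method == 'gptq':
        return 'gptq path'
    else:
        raise ValueError(f'args.quant_method: {args.quant_method}')


print(dispatch_quant(SimpleNamespace(quant_method='gptq')))  # gptq path
dispatch_quant(SimpleNamespace(quant_method='aqlm'))  # ValueError: args.quant_method: aqlm
```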
2 changes: 1 addition & 1 deletion swift/llm/utils/dataset.py
@@ -2256,7 +2256,7 @@ def get_dataset(
         assert model_name is not None and model_author is not None
         dataset = _preprocess_self_cognition_dataset(dataset, model_name, model_author)

-        def _reduce_column(row):
+        def _reduce_column(row: Dict[str, Any]) -> Dict[str, Any]:
             res = {}
             if 'query' in row and isinstance(row['query'], (list, tuple)):
                 res['query'] = np.random.choice(row['query'])
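The newly annotated `_reduce_column` collapses list-valued columns of the self-cognition dataset to one randomly chosen variant per row. A self-contained sketch of the branch visible in the diff (the sample row is made up; the real helper may also handle other columns such as `response`):

```python
from typing import Any, Dict

import numpy as np


def _reduce_column(row: Dict[str, Any]) -> Dict[str, Any]:
    res = {}
    # Several query phrasings may be stored per row; keep one at random.
    if 'query' in row and isinstance(row['query'], (list, tuple)):
        res['query'] = np.random.choice(row['query'])
    return res


row = {'query': ['who are you?', 'please introduce yourself']}
print(_reduce_column(row))  # one of the two phrasings, chosen at random
```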
1 change: 0 additions & 1 deletion swift/llm/utils/media.py
@@ -1,6 +1,5 @@
 import os
 import shutil
-import time
 from typing import Any, Dict, List, Literal, Optional, Union

 import numpy as np
20 changes: 20 additions & 0 deletions swift/llm/utils/model.py
@@ -189,6 +189,7 @@ class ModelType:
     atom_7b = 'atom-7b'
     atom_7b_chat = 'atom-7b-chat'
     # llava
+    llava1_5_7b_chat = 'llava1_5-7b-chat'
     llava1_6_mistral_7b_instruct = 'llava1_6-mistral-7b-instruct'
     llava1_6_yi_34b_instruct = 'llava1_6-yi-34b-instruct'
     llama3_llava_next_8b = 'llama3-llava-next-8b'
@@ -4645,6 +4646,25 @@ def _new_generate(inputs=None, *args, **kwargs):
     model.generate = _new_generate


+@register_model(
+    ModelType.llava1_5_7b_chat,
+    'huangjintao/llava-1.5-7b-hf',
+    LoRATM.llama,
+    TemplateType.llava1_5,
+    eos_token='</s>',
+    support_flash_attn=True,
+    requires=['transformers>=4.36'],
+    tags=['multi-modal', 'vision'],
+    hf_model_id='llava-hf/llava-1.5-7b-hf')
+def get_model_tokenizer_llava1_5(model_dir: str, *args, **kwargs):
+    from transformers import AutoProcessor, LlavaForConditionalGeneration
+    processor = AutoProcessor.from_pretrained(model_dir)
+    model, tokenizer = get_model_tokenizer_with_flash_attn(
+        model_dir, *args, automodel_class=LlavaForConditionalGeneration, **kwargs)
+    tokenizer.processor = processor
+    return model, tokenizer
+
+
 @register_model(
     ModelType.llava1_6_yi_34b_instruct,
     'AI-ModelScope/llava-v1.6-34b',
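With this registration, llava-1.5 should load by its `ModelType` key like the other registered models. A hedged usage sketch (assuming `swift.llm` exposes `get_model_tokenizer` with its usual signature; the `device_map` kwarg is illustrative):

```python
from swift.llm import ModelType, get_model_tokenizer

# Resolves 'llava1_5-7b-chat' through the registry above; the custom loader
# also attaches the HF AutoProcessor to the tokenizer as `tokenizer.processor`.
model, tokenizer = get_model_tokenizer(ModelType.llava1_5_7b_chat, model_kwargs={'device_map': 'auto'})
processor = tokenizer.processor  # consumed later by Llava1_5Template.encode
```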
2 changes: 1 addition & 1 deletion swift/llm/utils/protocol.py
@@ -42,7 +42,7 @@ class XRequestConfig:

     n: int = 1
     seed: Optional[int] = None
-    stop: List[str] = field(default_factory=list)
+    stop: Optional[List[str]] = None
     stream: bool = False

     best_of: Optional[int] = None
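This pairs with the `deploy.py` hunks above: `stop` now defaults to `None`, so a request that omits it is distinguishable from one that explicitly sends an empty list, and both route handlers normalize `None` back to `[]` before handing the request to a backend that expects a list. A minimal sketch of the pattern (the dataclass is an illustrative stand-in, not the full `XRequestConfig`):

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Request:  # stand-in for ChatCompletionRequest / CompletionRequest
    stop: Optional[List[str]] = None


def handle(request: Request) -> List[str]:
    # Same normalization the two route handlers now perform.
    if request.stop is None:
        request.stop = []
    return request.stop


print(handle(Request()))             # []      (unset becomes an empty list)
print(handle(Request(stop=['\n'])))  # ['\n']  (explicit values pass through)
```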
77 changes: 48 additions & 29 deletions swift/llm/utils/template.py
@@ -40,6 +40,7 @@ class TemplateType:
     chatglm3 = 'chatglm3'
     llama = 'llama'  # llama2
     llama3 = 'llama3'
+    llava1_5 = 'llava1-5'
     llava_mistral_instruct = 'llava-mistral-instruct'
     llava_yi_instruct = 'llava-yi-instruct'
     llava_llama_instruct = 'llava-llama-instruct'
@@ -639,6 +640,11 @@ def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] =
             res['inputs_embeds'] = inputs_embeds
         else:
             res['input_ids'] = input_ids
+        # multimodal
+        pixel_values = [b['pixel_values'] for b in batch if b.get('pixel_values') is not None]
+        if len(pixel_values) > 0:
+            res['pixel_values'] = torch.concat(pixel_values)
+
         if loss_scale is not None:
             res['loss_scale'] = loss_scale
         return res
@@ -726,7 +732,7 @@ def register_template(template_type: str, template: Template, *, exist_ok: bool

 register_template(
     TemplateType.default,
-    Template([], ['### Human:\n', '{{QUERY}}\n\n', '### Assistant:\n'], ['\n\n'], [['eos_token_id']], DEFAULT_SYSTEM,
+    Template([], ['### Human:\n{{QUERY}}\n\n### Assistant:\n'], ['\n\n'], [['eos_token_id']], DEFAULT_SYSTEM,
              ['{{SYSTEM}}\n\n']))

@@ -930,7 +936,7 @@ def _init_template(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs) ->
 class GLM4VTemplate(GLMTemplate):

     def __init__(self):
-        super().__init__([], ['<|user|>\n', '{{QUERY}}<|assistant|>'], [], ['<|endoftext|>'], None,
+        super().__init__([], ['<|user|>\n{{QUERY}}<|assistant|>'], [], ['<|endoftext|>'], None,
                          ['<|system|>\n{{SYSTEM}}'])

     def check_example(self, example):
@@ -982,7 +988,7 @@ def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] =

 register_template(
     TemplateType.yi_vl,
-    YiVLTemplate([], ['### Human: ', '{{QUERY}}\n### Assistant:'], ['\n'], ['\n###'], yi_vl_default_system,
+    YiVLTemplate([], ['### Human: {{QUERY}}\n### Assistant:'], ['\n'], ['\n###'], yi_vl_default_system,
                  ['{{SYSTEM}}\n\n']),
     use_model=True,
     infer_media_type='round',
@@ -1202,8 +1208,8 @@ class InternvlTemplate(Template):
     num_image_token = 256

     def __init__(self):
-        super().__init__(['<s>'], ['<|im_start|>user\n', '{{QUERY}}<|im_end|><|im_start|>assistant\n'], ['<|im_end|>'],
-                         ['<|im_end|>'], self.system, ['<|im_start|>system\n{{SYSTEM}}'])
+        super().__init__(['<s>'], ['<|im_start|>user\n{{QUERY}}<|im_end|><|im_start|>assistant\n'], ['<|im_end|>'],
+                         ['<|im_end|>'], self.system, ['<|im_start|>system\n{{SYSTEM}}<|im_end|>'])

     def check_example(self, example):
         images = example.get('images') or []
@@ -1250,10 +1256,7 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any
     def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
         res = super().data_collator(batch, padding_to)
         assert all('pixel_values' in b for b in batch), 'Temporarily, Interval only supports data with images'
-        pixel_values = [b['pixel_values'] for b in batch if 'pixel_values' in b]
         image_flags = [b['image_flags'] for b in batch if 'image_flags' in b]
-        if pixel_values:
-            res['pixel_values'] = torch.concat(pixel_values)
         if image_flags:
             res['image_flags'] = torch.concat(image_flags)
         return res
@@ -1267,8 +1270,8 @@ class InternvlPhi3Template(InternvlTemplate):
     system = 'You are an AI assistant whose name is Phi-3.'

     def __init__(self):
-        Template.__init__(self, ['<s>'], ['<|user|>\n', [-100], '{{QUERY}}<|end|>\n<|assistant|>\n'], ['<|end|>\n'],
-                          ['<|end|>'], self.system, ['<s><|system|>\n{{SYSTEM}}<|end|>\n'])
+        Template.__init__(self, ['<s>'], ['<|user|>\n{{QUERY}}<|end|>\n<|assistant|>\n'], ['<|end|>\n'], ['<|end|>'],
+                          self.system, ['<s><|system|>\n{{SYSTEM}}<|end|>\n'])


 register_template(
@@ -1320,6 +1323,34 @@ def __init__(self):
                      'and other non-computer science questions, you will refuse to answer\n')))


+class Llava1_5Template(Template):
+
+    def __init__(self):
+        super().__init__(['<s>'], ['USER: {{QUERY}}\nASSISTANT:'], ['\n'], ['</s>'])
+
+    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, example) -> List[Context]:
+        assert media_type == 'image'
+        return ['<image>\n']
+
+    def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+        inputs, _ = super().encode(example)
+        if len(inputs) == 0:
+            return inputs, {}
+        images_path = example.get('images') or []
+        images = []
+        for image_path in images_path:
+            image = _read_from_path(image_path)
+            images.append(image)
+        image_processor = self.tokenizer.processor.image_processor
+        if images:
+            inputs['pixel_values'] = image_processor(images, return_tensors='pt')['pixel_values'].to(self.model.dtype)
+        return inputs, {}
+
+
+register_template(
+    TemplateType.llava1_5, Llava1_5Template(), use_model=True, infer_media_type='round', lazy_tokenize=True)
+
+
 class LLavaTemplate(Template):

     def __init__(self):
@@ -1387,8 +1418,8 @@ def __init__(self):


 class LLavaLlamaTemplate(Template):
-    llavallama_query_template = '<|start_header_id|>user<|end_header_id|>\n\n' \
-                                '{{QUERY}}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'
+    llavallama_query_template = ('<|start_header_id|>user<|end_header_id|>\n\n'
+                                 '{{QUERY}}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n')

     def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, example):
         return ['<image>\n']
@@ -1407,13 +1438,6 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any
         inputs['pixel_values'] = pixel_values.to(self.model.dtype)
         return inputs, {}

-    def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
-        res = super().data_collator(batch, padding_to)
-        pixel_values = [b['pixel_values'] for b in batch if 'pixel_values' in b]
-        if pixel_values:
-            res['pixel_values'] = torch.concat(pixel_values)
-        return res
-

 register_template(
     TemplateType.llava_llama_instruct,
@@ -1456,9 +1480,6 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any

     def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
         res = super().data_collator(batch, padding_to)
-        pixel_values = [b['pixel_values'] for b in batch if 'pixel_values' in b]
-        if pixel_values:
-            res['pixel_values'] = torch.concat(pixel_values)
         token_type_ids = [torch.tensor(b['token_type_ids']) for b in batch]
         token_type_ids = pad_sequence(token_type_ids, batch_first=True, padding_value=0)
         res['token_type_ids'] = token_type_ids
@@ -1519,9 +1540,7 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any

     def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
         res = super().data_collator(batch, padding_to)
-        pixel_values = [b['pixel_values'] for b in batch if 'pixel_values' in b]
-        if pixel_values:
-            res['pixel_values'] = torch.concat(pixel_values)
+        if 'pixel_values' in res:
             res['image_sizes'] = torch.concat([b['image_sizes'] for b in batch if 'image_sizes' in b])
         return res

@@ -1554,7 +1573,7 @@ class LLavaQwenTemplate(LLavaTemplate):
     llavayi_query_template = 'You are a helpful assistant'

     def __init__(self):
-        Template.__init__(self, [], ['<|im_start|>user\n', '{{QUERY}}<|im_end|>\n<|im_start|>assistant\n'],
+        Template.__init__(self, [], ['<|im_start|>user\n{{QUERY}}<|im_end|>\n<|im_start|>assistant\n'],
                           ['<|im_end|>\n'], ['<|im_end|>'], self.llavayi_query_template,
                           ['<|im_start|>system\n{{SYSTEM}}<|im_end|>\n'])
@@ -1827,7 +1846,7 @@ def get_generate_ids(generate_ids: Tensor, input_token_len: int) -> List[int]:

 register_template(
     TemplateType.minicpm_v,
-    MiniCPMVTemplate(['<s>{{SYSTEM}}'], ['<用户>', '{{QUERY}}<AI>'], [], ['</s>']),
+    MiniCPMVTemplate(['<s>{{SYSTEM}}'], ['<用户>{{QUERY}}<AI>'], [], ['</s>']),
     use_model=True,
     lazy_tokenize=True,
     infer_media_type='dialogue',
@@ -1837,7 +1856,7 @@ def get_generate_ids(generate_ids: Tensor, input_token_len: int) -> List[int]:
 register_template(
     TemplateType.minicpm_v_v2_5,
     MiniCPMVTemplate(['<|begin_of_text|>{{SYSTEM}}'], [
-        '<|start_header_id|>user<|end_header_id|>\n\n', '{{QUERY}}<|eot_id|>'
+        '<|start_header_id|>user<|end_header_id|>\n\n{{QUERY}}<|eot_id|>'
         '<|start_header_id|>assistant<|end_header_id|>\n\n'
     ], ['<|eot_id|>'], ['<|eot_id|>'],
     is_v2_5=True),
@@ -1893,7 +1912,7 @@ def get_generate_ids(generate_ids: Tensor, input_token_len: int) -> List[int]:
 class mPlugOwl2Template(Template):

     def __init__(self):
-        super().__init__(['{{SYSTEM}}'], ['USER: ', '{{QUERY}}ASSISTANT:'], ['</s>'], [['eos_token_id']])
+        super().__init__(['{{SYSTEM}}'], ['USER: {{QUERY}}ASSISTANT:'], ['</s>'], [['eos_token_id']])

     def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, example) -> List[Context]:
         assert media_type == 'image'
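To make the new template concrete: `Llava1_5Template` wraps each turn as `USER: ...\nASSISTANT:` with `</s>` as the stop token, and `replace_tag` expands an image slot to `<image>\n`. A rough rendering of a one-image, one-turn prompt under those rules (string assembly only; the real `encode` tokenizes, prepends the `<s>` BOS token, and produces `pixel_values` through the attached processor):

```python
# Illustrative rendering of ['USER: {{QUERY}}\nASSISTANT:'] with one <image> tag.
query = 'What is in this picture?'
prompt = 'USER: ' + '<image>\n' + query + '\nASSISTANT:'
print(prompt)
# USER: <image>
# What is in this picture?
# ASSISTANT:
```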
3 changes: 0 additions & 3 deletions swift/llm/utils/utils.py
@@ -93,9 +93,6 @@ def download_dataset(model_id: str, files: List[str], force_download: bool = Fal
 def _msdataset_ddp_load(*args, **kwargs):
     with safe_ddp_context():
         dataset = _old_msdataset_load(*args, **kwargs)
-
-    if is_dist():  # sync
-        dist.barrier()
     return dataset

 # monkey patching
2 changes: 2 additions & 0 deletions swift/utils/utils.py
@@ -27,6 +27,8 @@ def safe_ddp_context():
     yield
     if is_dist() and is_local_master():
         dist.barrier()
+    if is_dist():  # sync
+        dist.barrier()


 def check_json_format(obj: Any) -> Any:
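The barrier that `_msdataset_ddp_load` used to issue after the context exited (removed in `swift/llm/utils/utils.py` above) now lives at the end of `safe_ddp_context` itself, so every caller gets the final synchronization. A sketch of the resulting context manager, reconstructed from the visible lines (the opening non-master barrier is assumed from the usual master-runs-first pattern):

```python
from contextlib import contextmanager

import torch.distributed as dist

from swift.utils import is_dist, is_local_master  # helpers referenced in the diff


@contextmanager
def safe_ddp_context():
    if is_dist() and not is_local_master():
        dist.barrier()  # assumed: non-masters wait so the master runs the body first
    yield
    if is_dist() and is_local_master():
        dist.barrier()  # master releases the waiting ranks
    if is_dist():  # sync
        dist.barrier()  # new: all ranks rejoin before anyone proceeds
```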