From 3da568212af5a454631c92cdc77e504f5b656fd9 Mon Sep 17 00:00:00 2001
From: MqLeet
Date: Fri, 22 Nov 2024 15:40:26 +0800
Subject: [PATCH 1/3] [Add] GOT OCR 2.0 inference pipeline

---
 paddlemix/examples/GOT_OCR_2_0/README.md      |  59 ++
 .../GOT_OCR_2_0/configs/demo_dataset.json     |   6 +
 .../examples/GOT_OCR_2_0/got_ocr2_0_infer.py  |  91 ++
 paddlemix/models/GOT/__init__.py              |  13 +
 paddlemix/models/GOT/data/__init__.py         | 121 +++
 paddlemix/models/GOT/data/base_dataset.py     |  82 ++
 .../GOT/data/conversation_dataset_qwen.py     | 329 +++++++
 paddlemix/models/GOT/model/GOT_ocr_2_0.py     | 835 ++++++++++++++++++
 paddlemix/models/GOT/model/__init__.py        |  15 +
 .../models/GOT/model/plug/blip_process.py     | 413 +++++++++
 .../GOT/model/vision_encoder/__init__.py      |  13 +
 .../GOT/model/vision_encoder/got_vision_b.py  | 490 ++++++++++
 .../models/GOT/model/vision_encoder/vary_b.py | 487 ++++++++++
 paddlemix/models/GOT/utils/constants.py       |  33 +
 paddlemix/models/GOT/utils/conversation.py    | 474 ++++++++++
 paddlemix/models/GOT/utils/utils.py           | 253 ++++++
 16 files changed, 3714 insertions(+)
 create mode 100644 paddlemix/examples/GOT_OCR_2_0/README.md
 create mode 100644 paddlemix/examples/GOT_OCR_2_0/configs/demo_dataset.json
 create mode 100644 paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py
 create mode 100644 paddlemix/models/GOT/__init__.py
 create mode 100644 paddlemix/models/GOT/data/__init__.py
 create mode 100644 paddlemix/models/GOT/data/base_dataset.py
 create mode 100644 paddlemix/models/GOT/data/conversation_dataset_qwen.py
 create mode 100644 paddlemix/models/GOT/model/GOT_ocr_2_0.py
 create mode 100644 paddlemix/models/GOT/model/__init__.py
 create mode 100644 paddlemix/models/GOT/model/plug/blip_process.py
 create mode 100644 paddlemix/models/GOT/model/vision_encoder/__init__.py
 create mode 100644 paddlemix/models/GOT/model/vision_encoder/got_vision_b.py
 create mode 100644 paddlemix/models/GOT/model/vision_encoder/vary_b.py
 create mode 100644 paddlemix/models/GOT/utils/constants.py
 create mode 100644 paddlemix/models/GOT/utils/conversation.py
 create mode 100644 paddlemix/models/GOT/utils/utils.py

diff --git a/paddlemix/examples/GOT_OCR_2_0/README.md b/paddlemix/examples/GOT_OCR_2_0/README.md
new file mode 100644
index 000000000..e995efb70
--- /dev/null
+++ b/paddlemix/examples/GOT_OCR_2_0/README.md
@@ -0,0 +1,59 @@
+# GOT-OCR2.0
+
+## 1. Model Introduction
+
+[GOT-OCR2.0](https://arxiv.org/abs/2409.01704) (General OCR Theory) is a unified end-to-end OCR-2.0 model: it takes an image as input and outputs plain or formatted text, and additionally supports fine-grained (box- or color-guided) and multi-crop OCR. This repository provides the Paddle implementation of the `GOT-OCR2.0` model.
+
+
+## 2. Environment Setup
+- **python >= 3.10**
+- **paddlepaddle-gpu: the develop version is required**
+```
+# installation example
+python -m pip install paddlepaddle-gpu==0.0.0.post118 -f https://www.paddlepaddle.org.cn/whl/linux/gpu/develop.html
+```
+
+- paddlenlp >= 3.0.0 (flash_attn is enabled by default; installing from source is recommended)
+
+> Note:
+* Make sure the dependencies above are installed, otherwise the example cannot run. You also need to install the custom ops under paddlemix/external_ops with `python setup.py install`. If the ops still cannot be found after installation, additionally set PYTHONPATH.
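+
+Optionally, you can sanity-check the environment with the short Python snippet below before running the examples. It is only a convenience and not required by this example; it relies solely on the standard `paddle.utils.run_check()` helper and the package version strings, not on anything in this repository.
+
+```python
+# Optional sanity check for the environment described above.
+import paddle
+import paddlenlp
+
+print("paddlepaddle-gpu:", paddle.__version__)   # expect a develop (0.0.0.postxxx) build
+print("paddlenlp:", paddlenlp.__version__)       # expect >= 3.0.0
+paddle.utils.run_check()                         # verifies that the GPU installation works
+```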
+
+## 3. Inference
+
+1. plain texts OCR:
+```bash
+python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py --model_name_or_path /GOT_weights/ --image_file /an/image/file.png --ocr_type ocr
+```
+
+2. format texts OCR:
+```bash
+python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py --model_name_or_path /GOT_weights/ --image_file /an/image/file.png --ocr_type format
+```
+
+3. fine-grained OCR:
+```bash
+python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py --model_name_or_path /GOT_weights/ --image_file /an/image/file.png --ocr_type format/ocr --box [x1,y1,x2,y2]
+```
+```bash
+python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py --model_name_or_path /GOT_weights/ --image_file /an/image/file.png --ocr_type format/ocr --color red/green/blue
+```
+
+4. multi-crop OCR:
+```bash
+python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py --model_name_or_path /GOT_weights/ --image_file /an/image/file.png --multi_crop --ocr_type format/ocr
+```
+
+5. render the formatted OCR results:
+```bash
+python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py --model_name_or_path /GOT_weights/ --image_file /an/image/file.png --ocr_type format --render
+```
+
+## References
+```BibTeX
+@article{wei2024general,
+  title={General OCR Theory: Towards OCR-2.0 via a Unified End-to-end Model},
+  author={Wei, Haoran and Liu, Chenglong and Chen, Jinyue and Wang, Jia and Kong, Lingyu and Xu, Yanming and Ge, Zheng and Zhao, Liang and Sun, Jianjian and Peng, Yuang and others},
+  journal={arXiv preprint arXiv:2409.01704},
+  year={2024}
+}
+```
diff --git a/paddlemix/examples/GOT_OCR_2_0/configs/demo_dataset.json b/paddlemix/examples/GOT_OCR_2_0/configs/demo_dataset.json
new file mode 100644
index 000000000..e728195a8
--- /dev/null
+++ b/paddlemix/examples/GOT_OCR_2_0/configs/demo_dataset.json
@@ -0,0 +1,6 @@
+{
+    "synthdog_en": {
+        "images": "playground/data/synthdog-en/",
+        "annotations": "playground/opensource/synthdog_en.jsonl"
+    }
+}
diff --git a/paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py b/paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py
new file mode 100644
index 000000000..cba95bb7c
--- /dev/null
+++ b/paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py
@@ -0,0 +1,91 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
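+
+"""GOT-OCR2.0 inference example.
+
+Loads a pretrained GOT checkpoint with ``GOTQwenForCausalLM`` and a ``QWenTokenizer``,
+then runs ``model.chat`` (single image) or ``model.chat_crop`` (``--multi_crop``) on
+``--image_file``. ``--ocr_type`` selects plain (``ocr``) or formatted (``format``)
+output, ``--box``/``--color`` enable fine-grained OCR, and ``--render`` saves the
+formatted result to ``./demo.html``.
+"""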
+ +import argparse + +import paddle +from paddlenlp.transformers import QWenTokenizer + +from paddlemix.models.GOT.model import GOTQwenForCausalLM + +parser = argparse.ArgumentParser() + +parser.add_argument("--model_name_or_path", type=str, default="GOT-OCR2_0_pd", help="pretrained ckpt and tokenizer") +parser.add_argument("--image_file", type=str, default="yiyuan.jpeg") +parser.add_argument("--multi_crop", action="store_true") +parser.add_argument("--ocr_type", type=str, default="plain", choices=["ocr", "format"]) +parser.add_argument("--box", type=str, default="") +parser.add_argument("--color", type=str, default="") +parser.add_argument("--render", action="store_true") + +args = parser.parse_args() +model_name_or_path = args.model_name_or_path + +tokenizer = QWenTokenizer.from_pretrained(model_name_or_path) +# print('tokenizer:\n', tokenizer) +# print('tokenizer.added_tokens_encoder:\n', tokenizer.added_tokens_encoder) +# print('tokenizer.added_tokens_decoder:\n', tokenizer.added_tokens_decoder) +# PretrainedTokenizer(name_or_path='', +# vocab_size=151851, model_max_len=8000, padding_side='right', +# truncation_side='right', special_tokens={ +# 'pad_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False)}) +model = GOTQwenForCausalLM.from_pretrained( + model_name_or_path, dtype=paddle.bfloat16, pad_token_id=tokenizer.eos_token_id +).eval() +# print('tokenizer:\n', tokenizer) + + +# input test image +image_file = args.image_file +with paddle.no_grad(): + if args.multi_crop: + # multi-crop OCR: + res = model.chat_crop( + tokenizer, image_file, ocr_type=args.ocr_type, render=args.render, save_render_file="./demo.html" + ) + else: + # plain texts OCR + # format texts OCR + # fine-grained OCR + # render the formatted OCR results + res = model.chat( + tokenizer, + image_file, + ocr_type=args.ocr_type, + ocr_box=args.box, + ocr_color=args.color, + render=args.render, + save_render_file="./demo.html", + ) + + # plain texts OCR + # res = model.chat(tokenizer, image_file, ocr_type='ocr') + + # format texts OCR: + # res = model.chat(tokenizer, image_file, ocr_type='format') + + # fine-grained OCR: + # res = model.chat(tokenizer, image_file, ocr_type='ocr', ocr_box='') + # res = model.chat(tokenizer, image_file, ocr_type='format', ocr_box='') + # res = model.chat(tokenizer, image_file, ocr_type='ocr', ocr_color='') + # res = model.chat(tokenizer, image_file, ocr_type='format', ocr_color='') + + # multi-crop OCR: + # res = model.chat_crop(tokenizer, image_file, ocr_type='ocr') + # res = model.chat_crop(tokenizer, image_file, ocr_type='format') + + # render the formatted OCR results: + # res = model.chat(tokenizer, image_file, ocr_type='format', render=True, save_render_file = './demo.html') + + print(res) diff --git a/paddlemix/models/GOT/__init__.py b/paddlemix/models/GOT/__init__.py new file mode 100644 index 000000000..fd05a9208 --- /dev/null +++ b/paddlemix/models/GOT/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlemix/models/GOT/data/__init__.py b/paddlemix/models/GOT/data/__init__.py new file mode 100644 index 000000000..061cac8b2 --- /dev/null +++ b/paddlemix/models/GOT/data/__init__.py @@ -0,0 +1,121 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from functools import partial +from sys import meta_path +from typing import List, Union + +import paddle +import paddlenlp +from paddle import Tensor + +from paddlemix.models.GOT.data.conversation_dataset_qwen import ConversationDataset + +from ..utils.constants import * + +IGNORE_INDEX = -100 + + +# helpers +def pad_sequence_paddle(sequences, padding_value=0): + """ + Implement a function similar to PyTorch's pad_sequence in PaddlePaddle. + + Args: + - sequences (list of Tensor): The list of sequences to be padded. + - padding_value (float, optional): The value used for padding, default is 0. + + Returns: + - Tensor: The result of padding all sequences to the same length. + """ + # Calculate the maximum length + max_len = max([seq.shape[0] for seq in sequences]) + + # Pad sequences + padded_sequences = [] + for seq in sequences: + # Calculate the length to pad + padding_len = max_len - seq.shape[0] + + # Create a padding tensor + if padding_len > 0: + padding_tensor = paddle.full([padding_len] + list(seq.shape[1:]), padding_value, dtype=seq.dtype) + # Concatenate the original sequence and the padding tensor + padded_seq = paddle.concat([seq, padding_tensor], axis=0) + else: + padded_seq = seq + + padded_sequences.append(padded_seq) + + # Stack the padded sequences to form a batch + padded_batch = paddle.stack(padded_sequences, axis=0) + return padded_batch + + +def orig_pad_sequence( + sequences: Union[Tensor, List[Tensor]], + batch_first: bool = False, + padding_value: float = 0.0, +) -> Tensor: + if batch_first: + return pad_sequence_paddle(sequences, padding_value) + else: + assert False, "Not implemented" + + +@dataclass +class DataCollatorForSupervisedDataset(object): + tokenizer: paddlenlp.transformers.PretrainedTokenizer + + def __call__(self, instances): + input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) + images = [paddle.stack(instance["image"]) for instance in instances] + images_high = [paddle.stack(instance["image_high"]) for instance in instances] + images = list(zip(images, images_high)) + + pad_sequence = partial(orig_pad_sequence, batch_first=True) + + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id) + + labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) + + batch = dict( + input_ids=input_ids, + labels=labels, + attention_mask=input_ids.not_equal(paddle.to_tensor(self.tokenizer.pad_token_id)), + images=images, + ) + return batch + + +def make_supervised_data_module(interleave, with_box, 
tokenizer, data_args): + assert data_args.conversation_version == "mpt" + + train_dataset = ConversationDataset( + tokenizer=tokenizer, + # datasets=data_args.datasets, + meta_path=data_args.meta_path, + multimodal_cfg=dict( + sep_image_conv_front=data_args.sep_image_conv_front, + image_token_len=data_args.image_token_len, + image_aspect_ratio=data_args.image_aspect_ratio, + use_im_start_end=data_args.use_im_start_end, + image_processor=data_args.image_processor, + image_processor_high=data_args.image_processor_high, + box_limit=data_args.box_limit, + ), + ) + data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) + return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator) diff --git a/paddlemix/models/GOT/data/base_dataset.py b/paddlemix/models/GOT/data/base_dataset.py new file mode 100644 index 000000000..c8e2144c6 --- /dev/null +++ b/paddlemix/models/GOT/data/base_dataset.py @@ -0,0 +1,82 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# import copy +# import io +# import json +import logging + +# from typing import Dict, List, Optional, Sequence, Tuple, Union +from typing import Dict + +import paddle +import paddlenlp +from paddle.io import Dataset +from PIL import ImageFile # , Image + +ImageFile.LOAD_TRUNCATED_IMAGES = True +# from ..utils.constants import * + + +class BaseDataset(Dataset): + def __init__(self, datasets: str, tokenizer: paddlenlp.transformers.PretrainedTokenizer, multimodal_cfg: dict): + super(BaseDataset, self).__init__() + self.tokenizer = tokenizer + self.multimodal_cfg = multimodal_cfg + + logging.warning(f"Using {multimodal_cfg['image_token_len']} tokens for representing image") + + def image_processor(self, image): + # processor = self.multimodal_cfg['image_processor'] # the first processor, usually is the clip pretrained model (vit) + processor_high = self.multimodal_cfg[ + "image_processor_high" + ] # the second processor, usually is the designed image encoder (sam/swin/cnn) + image_high = image.copy() + + # Vary old codes + + # # TODO the 'keep', 'padding' only used for the first processor + # if self.multimodal_cfg['image_aspect_ratio'] == 'keep': + # max_hw, min_hw = max(image.size), min(image.size) + # aspect_ratio = max_hw / min_hw + # max_len, min_len = 448, 224 + # shortest_edge = int(min(max_len / aspect_ratio, min_len)) + # image = processor.preprocess(image, return_tensors='pt', do_center_crop=False, size={"shortest_edge": shortest_edge})['pixel_values'][0] + # elif self.multimodal_cfg['image_aspect_ratio'] == 'pad': + # def expand2square(pil_img, background_color): + # width, height = pil_img.size + # if width == height: + # return pil_img + # elif width > height: + # result = Image.new(pil_img.mode, (width, width), background_color) + # result.paste(pil_img) # for simpler box processing + # return result + # else: + # result = Image.new(pil_img.mode, (height, height), background_color) + # result.paste(pil_img) # for simpler box 
processing + # return result + # image = expand2square(image, tuple(int(x*255) for x in processor.image_mean)) + # image = processor.preprocess(image, return_tensors='pt', do_center_crop=False, size={"shortest_edge": 224})['pixel_values'][0] + # else: + # image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + + image_high = processor_high(image_high) + + return image_high + + def __len__(self): + return len(self.list_data_dict) + + def __getitem__(self, i) -> Dict[str, paddle.Tensor]: + pass diff --git a/paddlemix/models/GOT/data/conversation_dataset_qwen.py b/paddlemix/models/GOT/data/conversation_dataset_qwen.py new file mode 100644 index 000000000..fd4240374 --- /dev/null +++ b/paddlemix/models/GOT/data/conversation_dataset_qwen.py @@ -0,0 +1,329 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +# import io +import json +import logging + +# import os +import random + +# from typing import Dict, List, Optional, Sequence, Tuple, Union +from typing import Dict + +import paddle +from PIL import Image, ImageFile + +ImageFile.LOAD_TRUNCATED_IMAGES = True + +from megfile import smart_glob +from natsort import natsorted + +# from ..utils.constants import CONVERSATION_DATA +from ..utils.conversation import ( # conv_templates,; default_conversation, + SeparatorStyle, + conv_mpt, +) +from .base_dataset import BaseDataset + +IGNORE_INDEX = -100 +CONTROLLER_HEART_BEAT_EXPIRATION = 30 +WORKER_HEART_BEAT_INTERVAL = 15 + +LOGDIR = "log" + +IGNORE_INDEX = -100 +# DEFAULT_PAD_TOKEN = "[PAD]" + +DEFAULT_PAD_TOKEN = "<|endoftext|>" +DEFAULT_EOS_TOKEN = "" +DEFAULT_BOS_TOKEN = "" +DEFAULT_UNK_TOKEN = "" +DEFAULT_IMAGE_TOKEN = "" +DEFAULT_BOX_TOKEN = "" + +DEFAULT_IMAGE_PATCH_TOKEN = "" + +DEFAULT_IM_START_TOKEN = "" +DEFAULT_IM_END_TOKEN = "" + + +class ConversationDataset(BaseDataset): + """Conversation format dataset stage2 fine-tuning.""" + + def __init__(self, meta_path, tokenizer, multimodal_cfg): + super(ConversationDataset, self).__init__(meta_path, tokenizer, multimodal_cfg) + # v0 version format conversation + # default_conversation = conv_templates["mpt"] + logging.warning("Formatting inputs into conversation type: mpt-fixed") + logging.warning("Loading data...") + + list_data_dict = [] + list_image_path = [] + + # add your data [data1, data2, data3, .....] + # got_data_dict = { + # "pdf-ocr": ["data1"], + # #'scene-ocr': ["data3", "data4"] + # # ...... 
+ # } + # for name_all in datasets.split("+"): + # for name in got_data_dict[name_all]: + ds_collections = json.loads(open(meta_path).read()) + for ds_idx, ds_name in enumerate(ds_collections.keys()): + if 1: + # dataset = CONVERSATION_DATA[ds_name] + dataset = ds_collections[ds_name] + + data_path = dataset["annotations"] + if data_path.endswith(".json"): + data = json.load(open(data_path, "r")) + elif data_path.endswith(".jsonl"): + with open(data_path, "r") as f: + data = f.readlines() + for ii in range(len(data)): + data[ii] = json.loads(data[ii]) + else: + raise ValueError(f"Unknown file extension: {data_path}") + + list_data_dict.extend(data) + + image_path = dataset["images"] # image_root + + list_image_path.extend([image_path] * len(data)) + + logging.warning(f"Data from {data_path} provide {len(data)} conversations.") + + assert len(list_data_dict) == len(list_image_path) + logging.warning(f"{len(list_data_dict)} conversations in total.") + a_new_list = list(zip(list_data_dict, list_image_path)) + random.shuffle(a_new_list) + list_data_dict_new, list_image_path_new = zip(*a_new_list) + self.list_data_dict = list_data_dict_new + self.list_image_path = list_image_path_new + + self.im_patch_token = 151859 + + self.im_start_token = 151857 + + self.im_end_token = 151858 + + def multimodal_processor(self, sources, flag_num_patches): + for source in sources: + if self.multimodal_cfg["sep_image_conv_front"]: + assert DEFAULT_IMAGE_TOKEN in source[0]["value"] + source[0]["value"] = source[0]["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip() + source[0]["value"] = DEFAULT_IMAGE_TOKEN + conv_mpt.sep + conv_mpt.roles[0] + ": " + source[0]["value"] + + for sentence in source: + replace_token = DEFAULT_IMAGE_PATCH_TOKEN * self.multimodal_cfg["image_token_len"] * flag_num_patches + replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN + # sentence["value"] = str(sentence["value"]).replace('\qquad', '\quad') + sentence["value"] = str(sentence["value"]).replace(DEFAULT_IMAGE_TOKEN, replace_token) + return sources + + def _tokenize_fn(self, strings): + """Tokenize a list of strings.""" + tokenized_list = [ + self.tokenizer( + text, + return_tensors="pd", + padding="longest", + max_length=self.tokenizer.model_max_length, + truncation=True, + ) + for text in strings + ] + input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] + input_ids_lens = labels_lens = [ + tokenized.input_ids.not_equal(paddle.to_tensor(self.tokenizer.pad_token_id)).sum().item() + for tokenized in tokenized_list + ] + return dict( + input_ids=input_ids, + labels=labels, + input_ids_lens=input_ids_lens, + labels_lens=labels_lens, + ) + + def _mask_targets(self, target, tokenized_lens, speakers): + # cur_idx = 0 + cur_idx = tokenized_lens[0] + tokenized_lens = tokenized_lens[1:] + target[:cur_idx] = IGNORE_INDEX + for tokenized_len, speaker in zip(tokenized_lens, speakers): + if speaker.lower() == "human": + target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX + cur_idx += tokenized_len + + def token_processor(self, sources, image_name): + conv = conv_mpt.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + 
conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + input_ids = self.tokenizer( + conversations, + return_tensors="pd", + padding="longest", + max_length=self.tokenizer.model_max_length, + truncation=True, + ).input_ids + + # input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + targets = input_ids.clone() + assert conv.sep_style == SeparatorStyle.MPT + + # Mask targets + sep = conv.sep + conv.roles[1] + for conversation, target in zip(conversations, targets): + total_len = int(target.not_equal(paddle.to_tensor(self.tokenizer.pad_token_id)).sum()) + + rounds = conversation.split(conv.sep) + re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt + for conv_idx in range(3, len(rounds), 2): + re_rounds.append(conv.sep.join(rounds[conv_idx : conv_idx + 2])) # user + gpt + cur_len = 0 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(re_rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + round_len = len(self.tokenizer(rou).input_ids) + len(self.tokenizer(conv.sep).input_ids) + # round_len = len(tokenizer_image_token(rou, self.tokenizer)) + len(tokenizer_image_token(conv.sep, self.tokenizer)) + # instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) + instruction_len = len(self.tokenizer(parts[0]).input_ids) + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < self.tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)") + print(image_name) + + return dict( + input_ids=input_ids, + labels=targets, + ) + + def __getitem__(self, i) -> Dict[str, paddle.Tensor]: + # data = self.list_data_dict[i] + data = copy.deepcopy(self.list_data_dict[i]) + + if isinstance(data, dict): + image_list = [] + image_high_list = [] + flag_num_patches = 1 + if "image" in data: + image_path = self.list_image_path[i] + image_file = data["image"] + + # multi-crop or multi page, only support .png files + if ( + 0 + ): # ('.jpg' not in image_file and '.png' not in image_file and '.jpeg' not in image_file) and ('.jpg' not in image_path and '.png' not in image_path and '.jpeg' not in image_path): + if image_file[0] == "/": + patch_dir = image_path[:-1] + image_file + patches = smart_glob(patch_dir + "*.png") + else: + patch_dir = image_path + image_file + patches = smart_glob(patch_dir + "*.png") + + # print(patches) + if not patches: + print(f"cannot glob the dir {patch_dir}.") + return self.__getitem__(0) + + # sort multi images by name + patches = natsorted(patches) + flag_num_patches = len(patches) + + for patch in patches: + try: + image = Image.open(patch).convert("RGB") + except: + print(f"cannot identify image file {patch}.") + return self.__getitem__(0) + + try: + img = self.image_processor(image) + image_list.append(img) + image_high_list.append(img) + except: + print( + f"image {image_path + image_file + patch} are broken or grayscale! we thus select 0-th sample instead!" + ) + return self.__getitem__(0) + + else: + flag_num_patches = 1 + try: + image = Image.open(image_path + image_file).convert("RGB") + except: + print(f"cannot identify image file {image_file}.") + return self.__getitem__(0) + + try: + image = self.image_processor(image) + except: + print(f"image {image_file} are broken or grayscale! 
we thus select 0-th sample instead!") + return self.__getitem__(0) + + conversations = self.multimodal_processor([data["conversations"]], flag_num_patches) + # print(conversations) + # exit() + else: + conversations = [data] + + # align with fastchat & llava here, put the conversation into a list for tokenization + image_name = image_path + image_file + data_dict = self.token_processor(conversations, image_name) + data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0]) + + if isinstance(data, dict) and "image" in data: + if image_list and image_high_list: + data_dict["image"] = image_list + data_dict["image_high"] = image_high_list + else: + data_dict["image"] = [image] + data_dict["image_high"] = [image] + else: + # crop_size = self.multimodal_cfg['image_processor'].crop_size + # data_dict['image'] = [torch.zeros(3, crop_size['height'], crop_size['width'])] + # Vary for two image, GOT does not use the data_dict['image] + data_dict["image"] = [paddle.zeros([3, 1024, 1024])] + data_dict["image_high"] = [paddle.zeros([3, 1024, 1024])] + return data_dict diff --git a/paddlemix/models/GOT/model/GOT_ocr_2_0.py b/paddlemix/models/GOT/model/GOT_ocr_2_0.py new file mode 100644 index 000000000..4bfb2cad2 --- /dev/null +++ b/paddlemix/models/GOT/model/GOT_ocr_2_0.py @@ -0,0 +1,835 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
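+
+"""Paddle implementation of the GOT-OCR2.0 model.
+
+``GOTQwenModel`` extends ``Qwen2Model`` with a high-resolution vision tower
+(``build_GOT_vit_b``, a SAM-style ViT-B that yields 256 visual tokens per
+1024x1024 image) and a linear projector (``mm_projector_vary``) that maps the
+1024-d visual features into the language embedding space; the projected features
+replace the image-patch placeholder tokens in ``inputs_embeds``.
+``GOTQwenForCausalLM`` adds the LM head plus the ``chat`` and ``chat_crop``
+helpers, which build the prompt, preprocess the image(s), call ``generate``,
+and decode the OCR result.
+"""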
+ +from io import BytesIO +from typing import List, Optional # , Tuple, Union + +import paddle +import paddle.nn as nn + +# import paddle.nn.functional as F +# import paddlenlp +import requests +from paddlenlp.generation.stopping_criteria import ( # , TextStreamer; StoppingCriteria, + StoppingCriteriaList, +) +from paddlenlp.transformers import Qwen2Config, Qwen2ForCausalLM, Qwen2Model +from paddlenlp.transformers.model_outputs import CausalLMOutputWithPast +from PIL import Image + +DEFAULT_IMAGE_TOKEN = "" +DEFAULT_IMAGE_PATCH_TOKEN = "" +DEFAULT_IM_START_TOKEN = "" +DEFAULT_IM_END_TOKEN = "" + +import dataclasses + +from paddle.vision import transforms + +from .plug.blip_process import BlipImageEvalProcessor +from .vision_encoder.got_vision_b import build_GOT_vit_b + + +class Qwen2LMHead(nn.Layer): + def __init__(self, config, embedding_weights=None, transpose_y=False, tensor_parallel_output=1): + super(Qwen2LMHead, self).__init__() + self.config = config + vocab_size = config.vocab_size + + self.transpose_y = transpose_y + if transpose_y: + # only for weight from embedding_weights + if embedding_weights is not None: + self.weight = embedding_weights + else: + self.weight = self.create_parameter( + shape=[vocab_size, config.hidden_size], + dtype=paddle.get_default_dtype(), + ) + else: + # for weight from model init + self.weight = self.create_parameter( + shape=[config.hidden_size, vocab_size], + dtype=paddle.get_default_dtype(), + ) + + def forward(self, hidden_states, tensor_parallel_output=1): + logits = paddle.matmul(hidden_states, self.weight, transpose_y=self.transpose_y) + return logits + + +from enum import Enum, auto + + +class SeparatorStyle(Enum): + """Different separator style.""" + + SINGLE = auto() + TWO = auto() + MPT = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that keeps all conversation history.""" + + system: str + roles: List[str] + messages: List[List[str]] + offset: int + sep_style: SeparatorStyle = SeparatorStyle.SINGLE + sep: str = "<|im_end|>" + sep2: str = None + version: str = "Unknown" + + skip_next: bool = False + + def get_prompt(self): + if self.sep_style == SeparatorStyle.SINGLE: + ret = self.system + self.sep + "\n" + for role, message in self.messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + self.sep + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.TWO: + seps = [self.sep, self.sep2] + ret = self.system + seps[0] + for i, (role, message) in enumerate(self.messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + return ret + if self.sep_style == SeparatorStyle.MPT: + if self.system: + ret = self.system + self.sep + else: + ret = "" + for role, message in self.messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + message + self.sep + else: + ret += role + return ret + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + def append_message(self, role, message): + self.messages.append([role, message]) + + def copy(self): + return Conversation( + system=self.system, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + ) + + +class KeywordsStoppingCriteria(StoppingCriteriaList): + def __init__(self, keywords, tokenizer, input_ids): + self.keywords = keywords + self.keyword_ids = 
[tokenizer(keyword).input_ids for keyword in keywords] + self.keyword_ids = [ + keyword_id[0] for keyword_id in self.keyword_ids if type(keyword_id) is list and len(keyword_id) == 1 + ] + self.tokenizer = tokenizer + self.start_len = None + self.input_ids = input_ids + + def __call__(self, output_ids: paddle.Tensor, scores: paddle.Tensor, **kwargs) -> bool: + if self.start_len is None: + self.start_len = self.input_ids.shape[1] + else: + for keyword_id in self.keyword_ids: + if output_ids[0, -1] == keyword_id: + return True + outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len :], skip_special_tokens=True)[0] + for keyword in self.keywords: + if keyword in outputs: + return True + return False + + +class GOTImageEvalProcessor: + def __init__(self, image_size=384, mean=None, std=None): + if mean is None: + mean = (0.48145466, 0.4578275, 0.40821073) + if std is None: + std = (0.26862954, 0.26130258, 0.27577711) + + self.normalize = transforms.Normalize(mean, std) + + self.transform = transforms.Compose( + [ + transforms.Resize((image_size, image_size), interpolation="bicubic"), + transforms.ToTensor(), + self.normalize, + ] + ) + + def __call__(self, item): + return self.transform(item) + + +class GOTConfig(Qwen2Config): + model_type = "GOT" + + +class GOTQwenModel(Qwen2Model): + config_class = GOTConfig + + def __init__(self, config: Qwen2Config): + super(GOTQwenModel, self).__init__(config) + self.vision_tower_high = build_GOT_vit_b() + self.mm_projector_vary = nn.Linear(1024, 1024) + + def initialize_vision_modules( + self, + vision_tower, + pretrained_stage1_model=None, + freeze_vision_tower=False, + use_im_start_end=False, + vision_select_layer=-1, + dtype=paddle.float16, + ): + # Vary old codes, not use in GOT + image_processor = BlipImageEvalProcessor(image_size=1024) + # 1024*1024 + + image_processor_high = BlipImageEvalProcessor(image_size=1024) + + self.vision_tower_high = self.vision_tower_high.to(dtype=dtype) + + self.mm_projector_vary = self.mm_projector_vary.to(dtype=dtype) + + image_token_len = 256 + + self.config.vision_tower = vision_tower + self.config.image_token_len = image_token_len + + self.config.use_im_start_end = True + + self.config.vision_select_layer = vision_select_layer + self.config.freeze_vision_tower = freeze_vision_tower + + return dict( + image_processor=image_processor, + image_processor_high=image_processor_high, + image_token_len=image_token_len, + ) + + def forward( + self, + input_ids: paddle.Tensor = None, + attention_mask: Optional[paddle.Tensor] = None, + position_ids: Optional[paddle.Tensor] = None, + past_key_values: Optional[List[paddle.Tensor]] = None, + inputs_embeds: Optional[paddle.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + images: Optional[paddle.Tensor] = None, + return_dict: Optional[bool] = None, + ): + # HACK: replace back original embeddings for LLaVA pretraining + orig_embeds_params = getattr(self, "orig_embeds_params", None) + if orig_embeds_params is not None: + with paddle.no_grad(): + self.get_input_embeddings().weight[: -self.num_new_tokens] = orig_embeds_params[ + : -self.num_new_tokens + ].data + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + vision_tower_high = getattr(self, "vision_tower_high", None) + + if vision_tower_high is not None and (input_ids.shape[1] != 1 or self.training) and images is not None: + use_im_start_end = getattr(self.config, "use_im_start_end", -1) + + # 
vision_select_layer = getattr(self.config, "vision_select_layer", -1) + im_patch_token = getattr(self.config, "im_patch_token", -1) + im_start_token = getattr(self.config, "im_start_token", -1) + im_end_token = getattr(self.config, "im_end_token", -1) + # freeze_vision_tower = getattr(self.config, "freeze_vision_tower", False) + + im_patch_token = 151859 + im_start_token = 151857 + im_end_token = 151858 + + image_features = [] + + for image in images: + if self.training: + image = image[1] + P, C, H, W = image.shape + if P == 1: + with paddle.set_grad_enabled(False): + cnn_feature = vision_tower_high(image) + cnn_feature = cnn_feature.flatten(2).transpose([0, 2, 1]) # 256*1024 + image_feature = self.mm_projector_vary(cnn_feature) + image_features.append(image_feature) + + else: + image_patches = paddle.unbind(image) + image_patches_features = [] + for image_patch in image_patches: + image_p = paddle.stack([image_patch]) + with paddle.set_grad_enabled(False): + cnn_feature_p = vision_tower_high(image_p) + cnn_feature_p = cnn_feature_p.flatten(2).transpose([0, 2, 1]) + image_feature_p = self.mm_projector_vary(cnn_feature_p) + image_patches_features.append(image_feature_p) + image_feature = paddle.concat(image_patches_features, axis=1) + image_features.append(image_feature) + + dummy_image_features_2 = paddle.zeros([256, 1024], dtype=inputs_embeds.dtype) + # dummy_image_features_2 = self.mm_projector_vary(dummy_image_features_2) + dummy_image_features = dummy_image_features_2 + use_im_start_end = True + new_input_embeds = [] + for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features): + if (cur_input_ids == im_patch_token).sum() == 0: + # multimodal LLM, but the current sample is not multimodal + cur_input_embeds = cur_input_embeds + (0.0 * dummy_image_features).sum() + new_input_embeds.append(cur_input_embeds) + continue + + if use_im_start_end: + if (cur_input_ids == im_start_token).sum() != (cur_input_ids == im_end_token).sum(): + raise ValueError("The number of image start tokens and image end tokens should be the same.") + + image_start_tokens = paddle.where(cur_input_ids == im_start_token)[0] + for image_start_token_pos, per_cur_image_features in zip(image_start_tokens, cur_image_features): + num_patches = per_cur_image_features.shape[0] + + if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token: + raise ValueError("The image end token should follow the image start token.") + + cur_input_embeds = paddle.concat( + ( + cur_input_embeds[: image_start_token_pos + 1], + per_cur_image_features, + cur_input_embeds[image_start_token_pos + num_patches + 1 :], + ), + axis=0, + ) + + new_input_embeds.append(cur_input_embeds) + else: + raise NotImplementedError + + inputs_embeds = paddle.stack(new_input_embeds, axis=0) + + return super().forward( + input_ids=None, + attention_mask=attention_mask, # [1, 1, 1, 800] + past_key_values=past_key_values, # None + inputs_embeds=inputs_embeds, # [1, 800, 1024] + use_cache=use_cache, # True + position_ids=position_ids, # [1, 1, 1, 800] + output_attentions=output_attentions, # False + output_hidden_states=output_hidden_states, # False + return_dict=return_dict, # False + ) + + +class GOTQwenForCausalLM(Qwen2ForCausalLM): + config_class = GOTConfig + + def __init__(self, config): + super(Qwen2ForCausalLM, self).__init__(config) + self.qwen2 = GOTQwenModel(config) + + self.vocab_size = config.vocab_size + # self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias_attr=False) + + if 
config.tie_word_embeddings: + self.lm_head = Qwen2LMHead(config, embedding_weights=self.qwen2.embed_tokens.weight, transpose_y=True) + self.tie_weights() + else: + self.lm_head = Qwen2LMHead(config) + + # Initialize weights and apply final processing + # self.post_init() + + def get_model(self): + return self.qwen2 + + def forward( + self, + input_ids: paddle.Tensor = None, + attention_mask: Optional[paddle.Tensor] = None, + position_ids: Optional[paddle.Tensor] = None, + past_key_values: Optional[List[paddle.Tensor]] = None, + inputs_embeds: Optional[paddle.Tensor] = None, + labels: Optional[paddle.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + images: Optional[paddle.Tensor] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.qwen2( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + images=images, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.astype(dtype="float32") + + # logits + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :] + shift_labels = labels[..., 1:] + # Flatten the tokens + # loss_fct = nn.CrossEntropyLoss() + loss_fct = nn.CrossEntropyLoss(reduction="sum") + shift_logits = shift_logits.reshape([-1, self.config.vocab_size]) + shift_labels = shift_labels.reshape([-1]) + # Enable model parallelism + loss = loss_fct(shift_logits, shift_labels) + label_sum = paddle.sum(shift_labels != -100).cast("float32") + loss = loss / label_sum + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + # input_ids [1, 287], past_key_values=None, attention_mask [1, 287], inputs_embeds=None + # kwargs ['images', 'use_cache', 'cache_position'] + # [1, 3, 1024, 1024], True, [0,,,,286] + + # input_ids [1, 288], past_key_values len(past_key_values)=24, attention_mask [1, 288], inputs_embeds=None + # kwargs ['images', 'use_cache', 'cache_position'] + # [1, 3, 1024, 1024], True, [287] + + batch_size, seq_length = input_ids.shape + attention_mask = paddle.ones((batch_size, seq_length), dtype=paddle.bool) + + # Omit tokens covered by past_key_values + if past_key_values is not None: + # if isinstance(past_key_values, Cache): ### + # cache_length = past_key_values.get_seq_length() + # past_length = past_key_values.seen_tokens + # max_cache_length = past_key_values.get_max_length() + # else: + past_length = past_key_values[0][0].shape[1] # [1, 800, 16, 64] + # max_cache_length = None + # cache_length = past_length + + # Keep only the unprocessed tokens: + # 1 - 
If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + # if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + # input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # # input_ids based on the past_length. + # el + if past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + # if ( + # max_cache_length is not None + # and attention_mask is not None + # and cache_length + input_ids.shape[1] > max_cache_length + # ): + # attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.astype(dtype="int64").cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "images": kwargs.get("images", None), + } + ) + return model_inputs + + def initialize_vision_tokenizer( + self, + tokenizer, + freeze_lm_model=False, + pretrained_stage1_model=None, + ): + config = self.get_model().config + + self.resize_token_embeddings(len(tokenizer)) + + config.im_patch_token = 151859 + + config.use_im_start_end = True + + if config.use_im_start_end: + self.resize_token_embeddings(len(tokenizer)) + config.im_start_token, config.im_end_token = 151857, 151858 + + def load_image(self, image_file): + if image_file.startswith("http") or image_file.startswith("https"): + response = requests.get(image_file) + image = Image.open(BytesIO(response.content)).convert("RGB") + else: + image = Image.open(image_file).convert("RGB") + return image + + def chat( + self, + tokenizer, + image_file, + ocr_type, + ocr_box="", + ocr_color="", + render=False, + save_render_file=None, + print_prompt=False, + gradio_input=False, + stream_flag=False, + ): + + image_processor_high = GOTImageEvalProcessor(image_size=1024) + + use_im_start_end = True + + image_token_len = 256 + + if gradio_input: + image = image_file.copy() + else: + image = self.load_image(image_file) + + w, h = image.size + + if ocr_type == "format": + qs = "OCR with format: " + else: + qs = "OCR: " + + if ocr_box: + bbox = eval(ocr_box) + if len(bbox) == 2: + bbox[0] = int(bbox[0] / w * 1000) + bbox[1] = int(bbox[1] / h * 1000) + if len(bbox) == 4: + bbox[0] = int(bbox[0] / w * 1000) + bbox[1] = int(bbox[1] / h * 1000) + bbox[2] = int(bbox[2] / w * 1000) + bbox[3] = int(bbox[3] / h * 1000) + if ocr_type == "format": + qs = str(bbox) + " " + "OCR with format: " + else: + qs = str(bbox) + " " + "OCR: " + + if ocr_color: + if ocr_type == "format": + 
qs = "[" + ocr_color + "]" + " " + "OCR with format: " + else: + qs = "[" + ocr_color + "]" + " " + "OCR: " + + if use_im_start_end: + qs = ( + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN + "\n" + qs + ) + else: + qs = DEFAULT_IMAGE_TOKEN + "\n" + qs + + conv_mpt = Conversation( + system="""<|im_start|>system + You should follow the instructions carefully and explain your answers in detail.""", + # system = None, + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + version="mpt", + messages=(), + offset=0, + sep_style=SeparatorStyle.MPT, + sep="<|im_end|>", + ) + + conv = conv_mpt.copy() + conv.append_message(conv.roles[0], qs) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + if print_prompt: + print("prompt", prompt) + + inputs = tokenizer([prompt]) + + image_tensor_1 = image_processor_high(image) + + input_ids = paddle.to_tensor(inputs.input_ids) + + # print('input_ids', input_ids.shape, input_ids.sum().item(), input_ids) + # [1, 287] + + stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 + keywords = [stop_str] + stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) + + output_ids = self.generate( + input_ids, + images=[image_tensor_1.unsqueeze(0).cast(paddle.bfloat16)], + do_sample=False, + num_beams=1, + no_repeat_ngram_size=20, + max_new_tokens=4096, + stopping_criteria=stopping_criteria, # list of stopping criteria + )[0] + # print('output_ids:\n', output_ids.shape, output_ids.sum().item(), output_ids) + + # outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() + outputs = tokenizer.decode(output_ids[0]).strip() + # print('outputs', outputs) + + if outputs.endswith(stop_str): + outputs = outputs[: -len(stop_str)] + outputs = outputs.strip() + response_str = outputs + return response_str + + def dynamic_preprocess(self, image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True): + def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): + best_ratio_diff = float("inf") + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}') + return best_ratio + + orig_width, orig_height = image.size + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = set( + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if i * j <= max_num and i * j >= min_num + ) + # print(target_ratios) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size + ) + + # print(target_aspect_ratio) + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // 
image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + return processed_images + + def chat_crop( + self, + tokenizer, + image_file, + ocr_type, + render=False, + save_render_file=None, + print_prompt=False, + gradio_input=False, + stream_flag=False, + ): + # Model + multi_page = False + + image_processor_high = GOTImageEvalProcessor(image_size=1024) + + use_im_start_end = True + + image_token_len = 256 + + image_list = [] + + # if len(image_file_list)>1: + # multi_page = True + + if multi_page: + qs = "OCR with format across multi pages: " + # only for png files + # import glob + # from natsort import natsorted + # patches = glob.glob(image_file + '/*png') + patches = image_file + # patches = natsorted(patches) + sub_images = [] + for sub_image in patches: + sub_images.append(self.load_image(sub_image)) + + ll = len(patches) + + else: + if ocr_type == "format": + qs = "OCR with format upon the patch reference: " + else: + qs = "OCR upon the patch reference: " + if gradio_input: + img = image_file.copy() + else: + img = self.load_image(image_file) + sub_images = self.dynamic_preprocess(img) + ll = len(sub_images) + + for image in sub_images: + image_tensor_1 = image_processor_high(image) + image_list.append(image_tensor_1) + + image_list = paddle.stack(image_list) + + print("====new images batch size======: \n", image_list.shape) + + if use_im_start_end: + qs = ( + DEFAULT_IM_START_TOKEN + + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len * ll + + DEFAULT_IM_END_TOKEN + + "\n" + + qs + ) + else: + qs = DEFAULT_IMAGE_TOKEN + "\n" + qs + + conv_mpt = Conversation( + system="""<|im_start|>system + You should follow the instructions carefully and explain your answers in detail.""", + # system = None, + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + version="mpt", + messages=(), + offset=0, + sep_style=SeparatorStyle.MPT, + sep="<|im_end|>", + ) + + conv = conv_mpt.copy() + conv.append_message(conv.roles[0], qs) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + if print_prompt: + print(prompt) + + inputs = tokenizer([prompt]) + + input_ids = paddle.to_tensor(inputs.input_ids) + # print('input_ids', input_ids.shape, input_ids.sum().item(), input_ids) + + stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 + keywords = [stop_str] + stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) + + output_ids = self.generate( + input_ids, + images=[image_list.cast(paddle.bfloat16)], + do_sample=False, + num_beams=1, + # no_repeat_ngram_size = 20, + max_new_tokens=4096, + stopping_criteria=stopping_criteria, + )[0] + + # outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() + outputs = tokenizer.decode(output_ids[0]).strip() + + if outputs.endswith(stop_str): + outputs = outputs[: -len(stop_str)] + outputs = outputs.strip() + response_str = outputs + return response_str diff --git a/paddlemix/models/GOT/model/__init__.py b/paddlemix/models/GOT/model/__init__.py new file mode 100644 index 000000000..d6794b65d --- /dev/null +++ b/paddlemix/models/GOT/model/__init__.py @@ -0,0 +1,15 @@ +# 
Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .GOT_ocr_2_0 import GOTConfig, GOTQwenForCausalLM, GOTQwenModel diff --git a/paddlemix/models/GOT/model/plug/blip_process.py b/paddlemix/models/GOT/model/plug/blip_process.py new file mode 100644 index 000000000..6ba5fc558 --- /dev/null +++ b/paddlemix/models/GOT/model/plug/blip_process.py @@ -0,0 +1,413 @@ +import paddle + +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" +import cv2 +import numpy as np + +# from PIL import Image + + +class BaseProcessor: + def __init__(self): + self.transform = lambda x: x + return + + def __call__(self, item): + return self.transform(item) + + +class BlipImageBaseProcessor(BaseProcessor): + def __init__(self, mean=None, std=None): + if mean is None: + mean = 0.48145466, 0.4578275, 0.40821073 + if std is None: + std = 0.26862954, 0.26130258, 0.27577711 + self.normalize = paddle.vision.transforms.Normalize(mean, std) + + +def identity_func(img): + return img + + +def autocontrast_func(img, cutoff=0): + """ + same output as PIL.ImageOps.autocontrast + """ + n_bins = 256 + + def tune_channel(ch): + n = ch.size + cut = cutoff * n // 100 + if cut == 0: + high, low = ch.max(), ch.min() + else: + hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) + low = np.argwhere(np.cumsum(hist) > cut) + low = 0 if tuple(low.shape)[0] == 0 else low[0] + high = np.argwhere(np.cumsum(hist[::-1]) > cut) + high = n_bins - 1 if tuple(high.shape)[0] == 0 else n_bins - 1 - high[0] + if high <= low: + table = np.arange(n_bins) + else: + scale = (n_bins - 1) / (high - low) + offset = -low * scale + table = np.arange(n_bins) * scale + offset + table[table < 0] = 0 + table[table > n_bins - 1] = n_bins - 1 + table = table.clip(0, 255).astype(np.uint8) + return table[ch] + + channels = [tune_channel(ch) for ch in cv2.split(img)] + out = cv2.merge(channels) + return out + + +def equalize_func(img): + """ + same output as PIL.ImageOps.equalize + PIL's implementation is different from cv2.equalize + """ + n_bins = 256 + + def tune_channel(ch): + hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) + non_zero_hist = hist[hist != 0].reshape(-1) + step = np.sum(non_zero_hist[:-1]) // (n_bins - 1) + if step == 0: + return ch + n = np.empty_like(hist) + n[0] = step // 2 + n[1:] = hist[:-1] + table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8) + return table[ch] + + channels = [tune_channel(ch) for ch in cv2.split(img)] + out = cv2.merge(channels) + return out + + +def rotate_func(img, degree, fill=(0, 0, 0)): + """ + like PIL, rotate by degree, not radians + """ + H, W = tuple(img.shape)[0], tuple(img.shape)[1] + center = W / 2, H / 2 + M = cv2.getRotationMatrix2D(center, degree, 1) + out = cv2.warpAffine(img, M, (W, H), borderValue=fill) + return out + + +def solarize_func(img, 
thresh=128): + """ + same output as PIL.ImageOps.posterize + """ + table = np.array([(el if el < thresh else 255 - el) for el in range(256)]) + table = table.clip(0, 255).astype(np.uint8) + out = table[img] + return out + + +def color_func(img, factor): + """ + same output as PIL.ImageEnhance.Color + """ + M = np.float32([[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]) * factor + np.float32( + [[0.114], [0.587], [0.299]] + ) + out = np.matmul(img, M).clip(0, 255).astype(np.uint8) + return out + + +def contrast_func(img, factor): + """ + same output as PIL.ImageEnhance.Contrast + """ + mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299])) + table = np.array([((el - mean) * factor + mean) for el in range(256)]).clip(0, 255).astype(np.uint8) + out = table[img] + return out + + +def brightness_func(img, factor): + """ + same output as PIL.ImageEnhance.Contrast + """ + table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8) + out = table[img] + return out + + +def sharpness_func(img, factor): + """ + The differences the this result and PIL are all on the 4 boundaries, the center + areas are same + """ + kernel = np.ones((3, 3), dtype=np.float32) + kernel[1][1] = 5 + kernel /= 13 + degenerate = cv2.filter2D(img, -1, kernel) + if factor == 0.0: + out = degenerate + elif factor == 1.0: + out = img + else: + out = img.astype(np.float32) + degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :] + out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate) + out = out.astype(np.uint8) + return out + + +def shear_x_func(img, factor, fill=(0, 0, 0)): + H, W = tuple(img.shape)[0], tuple(img.shape)[1] + M = np.float32([[1, factor, 0], [0, 1, 0]]) + out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) + return out + + +def translate_x_func(img, offset, fill=(0, 0, 0)): + """ + same output as PIL.Image.transform + """ + H, W = tuple(img.shape)[0], tuple(img.shape)[1] + M = np.float32([[1, 0, -offset], [0, 1, 0]]) + out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) + return out + + +def translate_y_func(img, offset, fill=(0, 0, 0)): + """ + same output as PIL.Image.transform + """ + H, W = tuple(img.shape)[0], tuple(img.shape)[1] + M = np.float32([[1, 0, 0], [0, 1, -offset]]) + out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) + return out + + +def posterize_func(img, bits): + """ + same output as PIL.ImageOps.posterize + """ + out = np.bitwise_and(img, np.uint8(255 << 8 - bits)) + return out + + +def shear_y_func(img, factor, fill=(0, 0, 0)): + H, W = tuple(img.shape)[0], tuple(img.shape)[1] + M = np.float32([[1, 0, 0], [factor, 1, 0]]) + out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) + return out + + +def cutout_func(img, pad_size, replace=(0, 0, 0)): + replace = np.array(replace, dtype=np.uint8) + H, W = tuple(img.shape)[0], tuple(img.shape)[1] + rh, rw = np.random.random(2) + pad_size = pad_size // 2 + ch, cw = int(rh * H), int(rw * W) + x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H) + y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W) + out = img.copy() + out[x1:x2, y1:y2, :] = replace + return out + + +def enhance_level_to_args(MAX_LEVEL): + def level_to_args(level): + return (level / MAX_LEVEL * 1.8 + 0.1,) + + return level_to_args + + +def shear_level_to_args(MAX_LEVEL, replace_value): + def level_to_args(level): + 
level = level / MAX_LEVEL * 0.3 + if np.random.random() > 0.5: + level = -level + return level, replace_value + + return level_to_args + + +def translate_level_to_args(translate_const, MAX_LEVEL, replace_value): + def level_to_args(level): + level = level / MAX_LEVEL * float(translate_const) + if np.random.random() > 0.5: + level = -level + return level, replace_value + + return level_to_args + + +def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value): + def level_to_args(level): + level = int(level / MAX_LEVEL * cutout_const) + return level, replace_value + + return level_to_args + + +def solarize_level_to_args(MAX_LEVEL): + def level_to_args(level): + level = int(level / MAX_LEVEL * 256) + return (level,) + + return level_to_args + + +def none_level_to_args(level): + return () + + +def posterize_level_to_args(MAX_LEVEL): + def level_to_args(level): + level = int(level / MAX_LEVEL * 4) + return (level,) + + return level_to_args + + +def rotate_level_to_args(MAX_LEVEL, replace_value): + def level_to_args(level): + level = level / MAX_LEVEL * 30 + if np.random.random() < 0.5: + level = -level + return level, replace_value + + return level_to_args + + +func_dict = { + "Identity": identity_func, + "AutoContrast": autocontrast_func, + "Equalize": equalize_func, + "Rotate": rotate_func, + "Solarize": solarize_func, + "Color": color_func, + "Contrast": contrast_func, + "Brightness": brightness_func, + "Sharpness": sharpness_func, + "ShearX": shear_x_func, + "TranslateX": translate_x_func, + "TranslateY": translate_y_func, + "Posterize": posterize_func, + "ShearY": shear_y_func, +} +translate_const = 10 +MAX_LEVEL = 10 +replace_value = 128, 128, 128 +arg_dict = { + "Identity": none_level_to_args, + "AutoContrast": none_level_to_args, + "Equalize": none_level_to_args, + "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value), + "Solarize": solarize_level_to_args(MAX_LEVEL), + "Color": enhance_level_to_args(MAX_LEVEL), + "Contrast": enhance_level_to_args(MAX_LEVEL), + "Brightness": enhance_level_to_args(MAX_LEVEL), + "Sharpness": enhance_level_to_args(MAX_LEVEL), + "ShearX": shear_level_to_args(MAX_LEVEL, replace_value), + "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value), + "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value), + "Posterize": posterize_level_to_args(MAX_LEVEL), + "ShearY": shear_level_to_args(MAX_LEVEL, replace_value), +} + + +class RandomAugment(object): + def __init__(self, N=2, M=10, isPIL=False, augs=[]): + self.N = N + self.M = M + self.isPIL = isPIL + if augs: + self.augs = augs + else: + self.augs = list(arg_dict.keys()) + + def get_random_ops(self): + sampled_ops = np.random.choice(self.augs, self.N) + return [(op, 0.5, self.M) for op in sampled_ops] + + def __call__(self, img): + if self.isPIL: + img = np.array(img) + ops = self.get_random_ops() + for name, prob, level in ops: + if np.random.random() > prob: + continue + args = arg_dict[name](level) + img = func_dict[name](img, *args) + return img + + +class VideoRandomAugment(object): + def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]): + self.N = N + self.M = M + self.p = p + self.tensor_in_tensor_out = tensor_in_tensor_out + if augs: + self.augs = augs + else: + self.augs = list(arg_dict.keys()) + + def get_random_ops(self): + sampled_ops = np.random.choice(self.augs, self.N, replace=False) + return [(op, self.M) for op in sampled_ops] + + def __call__(self, frames): + assert tuple(frames.shape)[-1] == 3, "Expecting last 
dimension for 3-channels RGB (b, h, w, c)." + if self.tensor_in_tensor_out: + frames = frames.numpy().astype(np.uint8) + num_frames = tuple(frames.shape)[0] + ops = num_frames * [self.get_random_ops()] + apply_or_not = num_frames * [np.random.random(size=self.N) > self.p] + frames = paddle.stack(x=list(map(self._aug, frames, ops, apply_or_not)), axis=0).astype(dtype="float32") + return frames + + def _aug(self, img, ops, apply_or_not): + for i, (name, level) in enumerate(ops): + if not apply_or_not[i]: + continue + args = arg_dict[name](level) + img = func_dict[name](img, *args) + return paddle.to_tensor(data=img) + + +class BlipImageTrainProcessor(BlipImageBaseProcessor): + def __init__(self, image_size=384, mean=None, std=None, min_scale=0.5, max_scale=1.0): + super().__init__(mean=mean, std=std) + self.transform = paddle.vision.transforms.Compose( + [ + paddle.vision.transforms.RandomResizedCrop( + image_size, scale=(min_scale, max_scale), interpolation="bicubic" + ), + RandomAugment(2, 5, isPIL=True, augs=["Identity", "Brightness", "Sharpness", "Equalize"]), + paddle.vision.transforms.ToTensor(), + self.normalize, + ] + ) + + def __call__(self, item): + return self.transform(item) + + +class BlipImageEvalProcessor(BlipImageBaseProcessor): + def __init__(self, image_size=384, mean=None, std=None): + super().__init__(mean=mean, std=std) + self.transform = paddle.vision.transforms.Compose( + [ + paddle.vision.transforms.Resize((image_size, image_size), interpolation="bicubic"), + paddle.vision.transforms.ToTensor(), + self.normalize, + ] + ) + + def __call__(self, item): + return self.transform(item) diff --git a/paddlemix/models/GOT/model/vision_encoder/__init__.py b/paddlemix/models/GOT/model/vision_encoder/__init__.py new file mode 100644 index 000000000..fd05a9208 --- /dev/null +++ b/paddlemix/models/GOT/model/vision_encoder/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlemix/models/GOT/model/vision_encoder/got_vision_b.py b/paddlemix/models/GOT/model/vision_encoder/got_vision_b.py new file mode 100644 index 000000000..38e25f495 --- /dev/null +++ b/paddlemix/models/GOT/model/vision_encoder/got_vision_b.py @@ -0,0 +1,490 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
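+
+# Shape sketch for the encoder defined in this file (illustrative, not part of
+# the upstream GOT-OCR2.0 sources): build_GOT_vit_b() below assembles a
+# SAM-style ViT-B that maps a 1024x1024 RGB image to a [B, 1024, 16, 16]
+# feature map (16x16 patch embed -> 64x64 tokens, neck to 256 channels, then
+# net_2/net_3 downsample 64 -> 32 -> 16 while widening 256 -> 512 -> 1024),
+# i.e. 256 image tokens of width 1024 after flattening:
+#
+#     encoder = build_GOT_vit_b()
+#     x = paddle.randn([1, 3, 1024, 1024])
+#     feats = encoder(x)                               # [1, 1024, 16, 16]
+#     tokens = feats.flatten(2).transpose([0, 2, 1])   # [1, 256, 1024]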
+ +# import math +from functools import partial +from typing import Optional, Tuple, Type + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +# class Projector(paddle.nn.Layer): + +# def __init__( +# self, +# width: 256, +# n_queries: int = 256, +# output_dim: int = 4096, +# **kwargs +# ): +# super().__init__() + +# norm_layer = partial(paddle.nn.LayerNorm, epsilon=1e-06) +# self.attn_pool = Resampler(grid_size=int(math.sqrt(n_queries)), +# embed_dim=output_dim, num_heads=output_dim // 128, kv_dim=width, +# norm_layer=norm_layer) +# self.ln_post = norm_layer(output_dim) +# self.proj = paddle.base.framework.EagerParamBase.from_tensor(tensor +# =output_dim ** -0.5 * paddle.randn(shape=[output_dim, output_dim])) + +# def forward(self, x: paddle.Tensor): +# x = self.attn_pool(x) +# x = self.ln_post(x) +# x = x @ self.proj +# return x + + +class MLPBlock(paddle.nn.Layer): + def __init__(self, embedding_dim: int, mlp_dim: int, act: Type[paddle.nn.Layer] = paddle.nn.GELU) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +class LayerNorm2d(paddle.nn.Layer): + def __init__(self, num_channels: int, epsilon: float = 1e-06) -> None: + super().__init__() + self.weight = paddle.base.framework.EagerParamBase.from_tensor(tensor=paddle.ones(shape=num_channels)) + self.bias = paddle.base.framework.EagerParamBase.from_tensor(tensor=paddle.zeros(shape=num_channels)) + self.epsilon = epsilon + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + u = x.mean(axis=1, keepdim=True) + s = (x - u).pow(y=2).mean(axis=1, keepdim=True) + x = (x - u) / paddle.sqrt(x=s + self.epsilon) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +class ImageEncoderViT(paddle.nn.Layer): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Layer] = nn.LayerNorm, + act_layer: Type[nn.Layer] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Layer): Normalization layer. + act_layer (nn.Layer): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. 
+ """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[paddle.base.framework.EagerParamBase.from_tensor] = None + if use_abs_pos: + self.pos_embed = paddle.base.framework.EagerParamBase.from_tensor( + tensor=paddle.zeros(shape=[1, img_size // patch_size, img_size // patch_size, embed_dim]) + ) + + self.blocks = paddle.nn.LayerList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2D( + embed_dim, + out_chans, + kernel_size=1, + bias_attr=False, + ), + LayerNorm2d(out_chans), + nn.Conv2D( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias_attr=False, + ), + LayerNorm2d(out_chans), + ) + + self.net_2 = nn.Conv2D(256, 512, kernel_size=3, stride=2, padding=1, bias_attr=False) + self.net_3 = nn.Conv2D(512, 1024, kernel_size=3, stride=2, padding=1, bias_attr=False) + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + for blk in self.blocks: + x = blk(x) + x = self.neck(x.transpose([0, 3, 1, 2])) + x = self.net_2(x) + x = self.net_3(x) + return x + + +class Block(paddle.nn.Layer): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Layer] = nn.LayerNorm, + act_layer: Type[nn.Layer] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Layer): Normalization layer. + act_layer (nn.Layer): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. 
+ """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + + self.window_size = window_size + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + shortcut = x + # import pdb; pdb.set_trace() + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(paddle.nn.Layer): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert input_size is not None, "Input size must be provided if using relative positional encoding." + self.rel_pos_h = paddle.base.framework.EagerParamBase.from_tensor( + tensor=paddle.zeros(shape=[2 * input_size[0] - 1, head_dim]) + ) + self.rel_pos_w = paddle.base.framework.EagerParamBase.from_tensor( + tensor=paddle.zeros(shape=[2 * input_size[1] - 1, head_dim]) + ) + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + B, H, W, _ = tuple(x.shape) + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape([B, H * W, 3, self.num_heads, -1]).transpose([2, 0, 3, 1, 4]) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape([3, B * self.num_heads, H * W, -1]).unbind(axis=0) + + attn = (q * self.scale) @ k.transpose([0, 2, 1]) # [-2, -1] + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = F.softmax(attn, axis=-1) + x = (attn @ v).reshape([B, self.num_heads, H, W, -1]).transpose([0, 2, 3, 1, 4]).reshape([B, H, W, -1]) + x = self.proj(x) + + return x + + +def window_partition(x: paddle.Tensor, window_size: int) -> Tuple[paddle.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. 
+ (Hp, Wp): padded height and width before partition + """ + B, H, W, C = tuple(x.shape) + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + # x.shape [1, 64, 64, 768] + if pad_h > 0 or pad_w > 0: + # x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) # torch + x = F.pad(x, pad=(0, pad_w, 0, pad_h), data_format="NHWC") # default NCHW + Hp, Wp = H + pad_h, W + pad_w + + x = x.reshape([B, Hp // window_size, window_size, Wp // window_size, window_size, C]) + windows = x.transpose([0, 1, 3, 2, 4, 5]).reshape([-1, window_size, window_size, C]) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: paddle.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> paddle.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = tuple(windows.shape)[0] // (Hp * Wp // window_size // window_size) + x = windows.reshape([B, Hp // window_size, Wp // window_size, window_size, window_size, -1]) + x = x.transpose([0, 1, 3, 2, 4, 5]).reshape([B, Hp, Wp, -1]) + if Hp > H or Wp > W: + x = x[:, :H, :W, :] + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: paddle.Tensor) -> paddle.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + if tuple(rel_pos.shape)[0] != max_rel_dist: + rel_pos_resized = paddle.nn.functional.interpolate( + rel_pos.reshape([1, tuple(rel_pos.shape)[0], -1]).transpose([0, 2, 1]), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape([-1, max_rel_dist]).transpose([1, 0]) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = paddle.arange(end=q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = paddle.arange(end=k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = q_coords - k_coords + (k_size - 1) * max(q_size / k_size, 1.0) + return rel_pos_resized[relative_coords.astype(dtype="int64")] + + +def add_decomposed_rel_pos( + attn: paddle.Tensor, + q: paddle.Tensor, + rel_pos_h: paddle.Tensor, + rel_pos_w: paddle.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> paddle.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. 
+ """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = tuple(q.shape) + r_q = q.reshape([B, q_h, q_w, dim]) + rel_h = paddle.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = paddle.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = (attn.reshape([B, q_h, q_w, k_h, k_w]) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).reshape( + [B, q_h * q_w, k_h * k_w] + ) + + return attn + + +class PatchEmbed(paddle.nn.Layer): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + """ + super().__init__() + self.proj = nn.Conv2D(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.transpose([0, 2, 3, 1]) + return x + + +def build_GOT_vit_b(checkpoint=None): + return _build_GOT_vision( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + checkpoint=checkpoint, + ) + + +def _build_GOT_vision( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, +): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + # image_embedding_size = image_size // vit_patch_size + image_encoder = ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(paddle.nn.LayerNorm, epsilon=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ) + + return image_encoder diff --git a/paddlemix/models/GOT/model/vision_encoder/vary_b.py b/paddlemix/models/GOT/model/vision_encoder/vary_b.py new file mode 100644 index 000000000..213571535 --- /dev/null +++ b/paddlemix/models/GOT/model/vision_encoder/vary_b.py @@ -0,0 +1,487 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# import math +from functools import partial +from typing import Optional, Tuple, Type + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +# class Projector(paddle.nn.Layer): + +# def __init__( +# self, +# width: 256, +# n_queries: int = 256, +# output_dim: int = 4096, +# **kwargs +# ): +# super().__init__() + +# norm_layer = partial(paddle.nn.LayerNorm, epsilon=1e-06) +# self.attn_pool = Resampler(grid_size=int(math.sqrt(n_queries)), +# embed_dim=output_dim, num_heads=output_dim // 128, kv_dim=width, +# norm_layer=norm_layer) +# self.ln_post = norm_layer(output_dim) +# self.proj = paddle.base.framework.EagerParamBase.from_tensor(tensor +# =output_dim ** -0.5 * paddle.randn(shape=[output_dim, output_dim])) + +# def forward(self, x: paddle.Tensor): +# x = self.attn_pool(x) +# x = self.ln_post(x) +# x = x @ self.proj +# return x + + +class MLPBlock(paddle.nn.Layer): + def __init__(self, embedding_dim: int, mlp_dim: int, act: Type[paddle.nn.Layer] = paddle.nn.GELU) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +class LayerNorm2d(paddle.nn.Layer): + def __init__(self, num_channels: int, epsilon: float = 1e-06) -> None: + super().__init__() + self.weight = paddle.base.framework.EagerParamBase.from_tensor(tensor=paddle.ones(shape=num_channels)) + self.bias = paddle.base.framework.EagerParamBase.from_tensor(tensor=paddle.zeros(shape=num_channels)) + self.epsilon = epsilon + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + u = x.mean(axis=1, keepdim=True) + s = (x - u).pow(y=2).mean(axis=1, keepdim=True) + x = (x - u) / paddle.sqrt(x=s + self.epsilon) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +class ImageEncoderViT(paddle.nn.Layer): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Layer] = nn.LayerNorm, + act_layer: Type[nn.Layer] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Layer): Normalization layer. + act_layer (nn.Layer): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. 
+ """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[paddle.base.framework.EagerParamBase.from_tensor] = None + if use_abs_pos: + self.pos_embed = paddle.base.framework.EagerParamBase.from_tensor( + tensor=paddle.zeros(shape=[1, img_size // patch_size, img_size // patch_size, embed_dim]) + ) + + self.blocks = paddle.nn.LayerList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2D( + embed_dim, + out_chans, + kernel_size=1, + bias_attr=False, + ), + LayerNorm2d(out_chans), + nn.Conv2D( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias_attr=False, + ), + LayerNorm2d(out_chans), + ) + + self.net_2 = nn.Conv2D(256, 512, kernel_size=3, stride=2, padding=1, bias_attr=False) + self.net_3 = nn.Conv2D(512, 1024, kernel_size=3, stride=2, padding=1, bias_attr=False) + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + for blk in self.blocks: + x = blk(x) + x = self.neck(x.transpose([0, 3, 1, 2])) + x = self.net_2(x) + x = self.net_3(x) + return x + + +class Block(paddle.nn.Layer): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Layer] = nn.LayerNorm, + act_layer: Type[nn.Layer] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Layer): Normalization layer. + act_layer (nn.Layer): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. 
+ """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + + self.window_size = window_size + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(paddle.nn.Layer): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert input_size is not None, "Input size must be provided if using relative positional encoding." + self.rel_pos_h = paddle.base.framework.EagerParamBase.from_tensor( + tensor=paddle.zeros(shape=[2 * input_size[0] - 1, head_dim]) + ) + self.rel_pos_w = paddle.base.framework.EagerParamBase.from_tensor( + tensor=paddle.zeros(shape=[2 * input_size[1] - 1, head_dim]) + ) + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + B, H, W, _ = tuple(x.shape) + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape([B, H * W, 3, self.num_heads, -1]).transpose([2, 0, 3, 1, 4]) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape([3, B * self.num_heads, H * W, -1]).unbind(axis=0) + + attn = q * self.scale @ k.transpose([0, 1, 3, 2]) # [-2, -1] + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = F.softmax(attn, axis=-1) + x = (attn @ v).reshape([B, self.num_heads, H, W, -1]).transpose([0, 2, 3, 1, 4]).reshape([B, H, W, -1]) + x = self.proj(x) + + return x + + +def window_partition(x: paddle.Tensor, window_size: int) -> Tuple[paddle.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. 
+ (Hp, Wp): padded height and width before partition + """ + B, H, W, C = tuple(x.shape) + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, pad=(0, 0, 0, pad_w, 0, pad_h), pad_from_left_axis=False) + Hp, Wp = H + pad_h, W + pad_w + + x = x.reshape([B, Hp // window_size, window_size, Wp // window_size, window_size, C]) + windows = x.transpose(perm=[0, 1, 3, 2, 4, 5]).reshape([-1, window_size, window_size, C]) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: paddle.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> paddle.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = tuple(windows.shape)[0] // (Hp * Wp // window_size // window_size) + x = windows.reshape([B, Hp // window_size, Wp // window_size, window_size, window_size, -1]) + x = x.transpose([0, 1, 3, 2, 4, 5]).reshape([B, Hp, Wp, -1]) + if Hp > H or Wp > W: + x = x[:, :H, :W, :] + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: paddle.Tensor) -> paddle.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + if tuple(rel_pos.shape)[0] != max_rel_dist: + rel_pos_resized = paddle.nn.functional.interpolate( + rel_pos.reshape([1, tuple(rel_pos.shape)[0], -1]).transpose([0, 2, 1]), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape([-1, max_rel_dist]).transpose([1, 0]) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = paddle.arange(end=q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = paddle.arange(end=k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = q_coords - k_coords + (k_size - 1) * max(q_size / k_size, 1.0) + return rel_pos_resized[relative_coords.astype(dtype="int64")] + + +def add_decomposed_rel_pos( + attn: paddle.Tensor, + q: paddle.Tensor, + rel_pos_h: paddle.Tensor, + rel_pos_w: paddle.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> paddle.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. 
+ """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = tuple(q.shape) + r_q = q.reshape([B, q_h, q_w, dim]) + rel_h = paddle.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = paddle.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = (attn.reshape([B, q_h, q_w, k_h, k_w]) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).reshape( + [B, q_h * q_w, k_h * k_w] + ) + + return attn + + +class PatchEmbed(paddle.nn.Layer): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + """ + super().__init__() + self.proj = nn.Conv2D(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.transpose([0, 2, 3, 1]) + return x + + +def build_GOT_vit_b(checkpoint=None): + return _build_GOT_vision( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + checkpoint=checkpoint, + ) + + +def _build_GOT_vision( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, +): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + # image_embedding_size = image_size // vit_patch_size + image_encoder = ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(paddle.nn.LayerNorm, epsilon=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ) + + return image_encoder diff --git a/paddlemix/models/GOT/utils/constants.py b/paddlemix/models/GOT/utils/constants.py new file mode 100644 index 000000000..5caa54e01 --- /dev/null +++ b/paddlemix/models/GOT/utils/constants.py @@ -0,0 +1,33 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +CONTROLLER_HEART_BEAT_EXPIRATION = 30 +WORKER_HEART_BEAT_INTERVAL = 15 + +LOGDIR = "log" + +IGNORE_INDEX = -100 +# DEFAULT_PAD_TOKEN = "[PAD]" + +DEFAULT_PAD_TOKEN = "<|endoftext|>" +DEFAULT_EOS_TOKEN = "" +DEFAULT_BOS_TOKEN = "" +DEFAULT_UNK_TOKEN = "" +DEFAULT_IMAGE_TOKEN = "" +DEFAULT_BOX_TOKEN = "" + +DEFAULT_IMAGE_PATCH_TOKEN = "" + +DEFAULT_IM_START_TOKEN = "" +DEFAULT_IM_END_TOKEN = "" diff --git a/paddlemix/models/GOT/utils/conversation.py b/paddlemix/models/GOT/utils/conversation.py new file mode 100644 index 000000000..1bb8a0d6c --- /dev/null +++ b/paddlemix/models/GOT/utils/conversation.py @@ -0,0 +1,474 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +from enum import Enum, auto +from typing import List # , Tuple + + +class SeparatorStyle(Enum): + """Different separator style.""" + + SINGLE = auto() + TWO = auto() + MPT = auto() + + +# simple_conv_multimodal = Conversation( +# system="You are GOT, a large language and vision assistant trained by Foundation Model Group, Megvii Technology." +# "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." +# "Follow the instructions carefully and explain your answers in detail.", +# # system="", +# roles=("Human", "Assistant"), +# messages=( +# ("Human", "Hi!"), +# ("Assistant", "Hi there! How can I help you today?\n") +# ), +# offset=2, +# sep_style=SeparatorStyle.SINGLE, +# sep="###", +# ) + +# conv_mpt = Conversation( +# system="""<|im_start|>system +# - You are a helpful language and vision assistant. +# - You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. 
+# - You should follow the instructions carefully and explain your answers in detail.""", +# roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), +# version="mpt", +# messages=(), +# offset=0, +# sep_style=SeparatorStyle.MPT, +# sep="<|im_end|>", +# ) + + +@dataclasses.dataclass +class Conversation: + """A class that keeps all conversation history.""" + + system: str + roles: List[str] + messages: List[List[str]] + offset: int + sep_style: SeparatorStyle = SeparatorStyle.SINGLE + sep: str = "<|im_end|>" + sep2: str = None + version: str = "Unknown" + + skip_next: bool = False + + def get_prompt(self): + if self.sep_style == SeparatorStyle.SINGLE: + ret = self.system + self.sep + "\n" + for role, message in self.messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + self.sep + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.TWO: + seps = [self.sep, self.sep2] + ret = self.system + seps[0] + for i, (role, message) in enumerate(self.messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + return ret + if self.sep_style == SeparatorStyle.MPT: + if self.system: + ret = self.system + self.sep + else: + ret = "" + for role, message in self.messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + message + self.sep + else: + ret += role + return ret + else: + raise ValueError(f"Invalid style: {self.sep_style}") + # if self.sep_style == SeparatorStyle.MPT: + # if self.system: + # ret = self.system + self.sep + # else: + # ret = '' + # for role, message in self.messages: + # if message: + # if type(message) is tuple: + # message, _, _ = message + # ret += role + message + self.sep + # # if 'user' in role: + # # ret += role + message + self.sep + "\n" + # # else: + # # ret += role + message + self.sep + # else: + # ret += role + # return ret + # else: + # raise ValueError(f"Invalid style: {self.sep_style}") + + def append_message(self, role, message): + self.messages.append([role, message]) + + def get_images(self, return_pil=False): + images = [] + for i, (role, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + if type(msg) is tuple: + import base64 + from io import BytesIO + + from PIL import Image + + msg, image, image_process_mode = msg + if image_process_mode == "Pad": + + def expand2square(pil_img, background_color=(122, 116, 104)): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + # result.paste(pil_img, (0, (width - height) // 2)) + result.paste(pil_img) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + # result.paste(pil_img, ((height - width) // 2, 0)) + result.paste(pil_img) + return result + + image = expand2square(image) + elif image_process_mode == "Crop": + max_hw, min_hw = max(image.size), min(image.size) + aspect_ratio = max_hw / min_hw + max_len, min_len = 800, 400 + shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) + longest_edge = int(shortest_edge * aspect_ratio) + W, H = image.size + if H > W: + H, W = longest_edge, shortest_edge + else: + H, W = shortest_edge, longest_edge + image = image.resize((W, H)) + elif image_process_mode == "Resize": + image = image.resize((224, 224)) + else: + raise ValueError(f"Invalid image_process_mode: {image_process_mode}") + + if 
return_pil: + images.append(image) + else: + buffered = BytesIO() + image.convert("RGB").save(buffered, format="JPEG") + img_b64_str = base64.b64encode(buffered.getvalue()).decode() + images.append(img_b64_str) + return images + + def to_gradio_chatbot(self): + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + if type(msg) is tuple: + import base64 + from io import BytesIO + + msg, image, image_process_mode = msg + max_hw, min_hw = max(image.size), min(image.size) + aspect_ratio = max_hw / min_hw + max_len, min_len = 800, 400 + shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) + longest_edge = int(shortest_edge * aspect_ratio) + W, H = image.size + if H > W: + H, W = longest_edge, shortest_edge + else: + H, W = shortest_edge, longest_edge + image = image.resize((W, H)) + # image = image.resize((224, 224)) + buffered = BytesIO() + image.save(buffered, format="JPEG") + img_b64_str = base64.b64encode(buffered.getvalue()).decode() + img_str = f'user upload image' + msg = msg.replace("", img_str) + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def copy(self): + return Conversation( + system=self.system, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + ) + + def dict(self): + if len(self.get_images()) > 0: + return { + "system": self.system, + "roles": self.roles, + "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + return { + "system": self.system, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + + +conv_v1 = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + messages=( + ("Human", "Give three tips for staying healthy."), + ( + "Assistant", + "Sure, here are three tips for staying healthy:\n" + "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. " + "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, " + "and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or " + "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening " + "activities at least two days per week.\n" + "2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, " + "vegetables, whole grains, lean proteins, and healthy fats can help support " + "your overall health. Try to limit your intake of processed and high-sugar foods, " + "and aim to drink plenty of water throughout the day.\n" + "3. Get enough sleep: Getting enough quality sleep is essential for your physical " + "and mental health. Adults should aim for seven to nine hours of sleep per night. " + "Establish a regular sleep schedule and try to create a relaxing bedtime routine to " + "help improve the quality of your sleep.", + ), + ), + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +conv_v1_2 = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. 
" + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + messages=( + ("Human", "What are the key differences between renewable and non-renewable energy sources?"), + ( + "Assistant", + "Renewable energy sources are those that can be replenished naturally in a relatively " + "short amount of time, such as solar, wind, hydro, geothermal, and biomass. " + "Non-renewable energy sources, on the other hand, are finite and will eventually be " + "depleted, such as coal, oil, and natural gas. Here are some key differences between " + "renewable and non-renewable energy sources:\n" + "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable " + "energy sources are finite and will eventually run out.\n" + "2. Environmental impact: Renewable energy sources have a much lower environmental impact " + "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, " + "and other negative effects.\n" + "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically " + "have lower operational costs than non-renewable sources.\n" + "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote " + "locations than non-renewable sources.\n" + "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different " + "situations and needs, while non-renewable sources are more rigid and inflexible.\n" + "6. Sustainability: Renewable energy sources are more sustainable over the long term, while " + "non-renewable sources are not, and their depletion can lead to economic and social instability.\n", + ), + ), + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +conv_vicuna_v1_1 = Conversation( + system="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=("USER", "ASSISTANT"), + version="v1", + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +# conv_mpt = Conversation( +# system="""<|im_start|>system +# - You are designed by Megvii(旷视), and your name is GOT. +# - 你叫GOT, 你来自旷视, 你是旷视开发的。 +# - 你擅长分析表格,仔细读图表中的内容,然后给出你的答案。""", +# roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), +# version="mpt", +# messages=(), +# offset=0, +# sep_style=SeparatorStyle.MPT, +# sep="<|im_end|>", +# ) + +conv_mpt = Conversation( + system="""<|im_start|>system +You should follow the instructions carefully and explain your answers in detail.""", + # system = None, + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + version="mpt", + messages=(), + offset=0, + sep_style=SeparatorStyle.MPT, + sep="<|im_end|>", +) + +conv_mpt_eval = Conversation( + system="", + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + version="mpt", + messages=(), + offset=0, + sep_style=SeparatorStyle.MPT, + sep="<|im_end|>", +) + +conv_mpt_text = Conversation( + system="""<|im_start|>system +- You are a helpful assistant chatbot trained by MosaicML. +- You answer questions. +- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. 
+- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""", + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + version="mpt", + messages=(), + offset=0, + sep_style=SeparatorStyle.MPT, + sep="<|im_end|>", +) + +conv_bair_v1 = Conversation( + system="BEGINNING OF CONVERSATION:", + roles=("USER", "GPT"), + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +# simple_conv = Conversation( +# system="You are GOT, a large language model trained by Foundation Model Group, Megvii Technology, based on LLaMA architecture." +# "You are designed to assist human with a variety of tasks using natural language." +# "Follow the instructions carefully.", +# roles=("Human", "Assistant"), +# messages=( +# ("Human", "Hi!"), +# ("Assistant", "Hi there! How can I help you today?\n") +# ), +# offset=2, +# sep_style=SeparatorStyle.SINGLE, +# sep="###", +# ) + + +simple_conv = Conversation( + system="", + roles=("Human", "Assistant"), + messages=(), + offset=0, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +simple_conv_multimodal = Conversation( + system="You are GOT, a large language and vision assistant trained by Foundation Model Group, Megvii Technology." + "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." + "Follow the instructions carefully and explain your answers in detail.", + # system="", + roles=("Human", "Assistant"), + messages=(("Human", "Hi!"), ("Assistant", "Hi there! How can I help you today?\n")), + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +simple_conv_mpt_multimodal = Conversation( + system="""<|im_start|>system +- You are GOT, a large language and vision assistant trained by Foundation Model Group, Megvii Technology. +- You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. +- You should follow the instructions carefully and explain your answers in detail.""", + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + version="mpt", + messages=(), + offset=0, + sep_style=SeparatorStyle.MPT, + sep="<|im_end|>", +) + +simple_conv_legacy = Conversation( + system="You are GOT, a large language model trained by Foundation Model Group, Megvii Technology." + "You are designed to assist human with a variety of tasks using natural language." + "Follow the instructions carefully.", + roles=("Human", "Assistant"), + messages=(("Human", "Hi!\n\n### Response:"), ("Assistant", "Hi there! How can I help you today?\n")), + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +conv_llava_v1 = Conversation( + system="You are GOT, a large language and vision assistant trained by Foundation Model Group, Megvii Technology." + "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." 
+ "Follow the instructions carefully and explain your answers in detail.", + roles=("USER", "ASSISTANT"), + version="v1", + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +default_conversation = conv_mpt +conv_templates = { + "default": simple_conv_multimodal, + "simple": simple_conv, + "simple_legacy": simple_conv_legacy, + "multimodal": simple_conv, + "mpt_multimodal": simple_conv_mpt_multimodal, + "llava_v1": conv_llava_v1, + "mpt_eval": conv_mpt_eval, + # fastchat + "v1": conv_vicuna_v1_1, + "bair_v1": conv_bair_v1, + "vicuna_v1_1": conv_vicuna_v1_1, + "mpt": conv_mpt, + "mpt_text": conv_mpt_text, +} + + +if __name__ == "__main__": + print(default_conversation.get_prompt()) diff --git a/paddlemix/models/GOT/utils/utils.py b/paddlemix/models/GOT/utils/utils.py new file mode 100644 index 000000000..8674e735c --- /dev/null +++ b/paddlemix/models/GOT/utils/utils.py @@ -0,0 +1,253 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# import datetime +# import logging + +import paddle + +# import requests +from paddlenlp.generation.stopping_criteria import ( # StoppingCriteriaList, + StoppingCriteria, +) + +# import logging.handlers +# import os +# import sys + + +server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" +moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." + +handler = None + + +# def build_logger(logger_name, logger_filename): +# global handler + +# formatter = logging.Formatter( +# fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", +# datefmt="%Y-%m-%d %H:%M:%S", +# ) + +# # Set the format of root handlers +# if not logging.getLogger().handlers: +# logging.basicConfig(level=logging.INFO) +# logging.getLogger().handlers[0].setFormatter(formatter) + +# # Redirect stdout and stderr to loggers +# stdout_logger = logging.getLogger("stdout") +# stdout_logger.setLevel(logging.INFO) +# sl = StreamToLogger(stdout_logger, logging.INFO) +# sys.stdout = sl + +# stderr_logger = logging.getLogger("stderr") +# stderr_logger.setLevel(logging.ERROR) +# sl = StreamToLogger(stderr_logger, logging.ERROR) +# sys.stderr = sl + +# # Get logger +# logger = logging.getLogger(logger_name) +# logger.setLevel(logging.INFO) + +# # Add a file handler for all loggers +# if handler is None: +# os.makedirs(LOGDIR, exist_ok=True) +# filename = os.path.join(LOGDIR, logger_filename) +# handler = logging.handlers.TimedRotatingFileHandler( +# filename, when='D', utc=True) +# handler.setFormatter(formatter) + +# for name, item in logging.root.manager.loggerDict.items(): +# if isinstance(item, logging.Logger): +# item.addHandler(handler) + +# return logger + + +# class StreamToLogger(object): +# """ +# Fake file-like stream object that redirects writes to a logger instance. 
+# """ +# def __init__(self, logger, log_level=logging.INFO): +# self.terminal = sys.stdout +# self.logger = logger +# self.log_level = log_level +# self.linebuf = '' + +# def __getattr__(self, attr): +# return getattr(self.terminal, attr) + +# def write(self, buf): +# temp_linebuf = self.linebuf + buf +# self.linebuf = '' +# for line in temp_linebuf.splitlines(True): +# # From the io.TextIOWrapper docs: +# # On output, if newline is None, any '\n' characters written +# # are translated to the system default line separator. +# # By default sys.stdout.write() expects '\n' newlines and then +# # translates them so this is still cross platform. +# if line[-1] == '\n': +# self.logger.log(self.log_level, line.rstrip()) +# else: +# self.linebuf += line + +# def flush(self): +# if self.linebuf != '': +# self.logger.log(self.log_level, self.linebuf.rstrip()) +# self.linebuf = '' + + +# def disable_torch_init(): +# """ +# Disable the redundant torch default initialization to accelerate model creation. +# """ +# import torch +# setattr(torch.nn.Linear, "reset_parameters", lambda self: None) +# setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + + +# def violates_moderation(text): +# """ +# Check whether the text violates OpenAI moderation API. +# """ +# url = "https://api.openai.com/v1/moderations" +# headers = {"Content-Type": "application/json", +# "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} +# text = text.replace("\n", "") +# data = "{" + '"input": ' + f'"{text}"' + "}" +# data = data.encode("utf-8") +# try: +# ret = requests.post(url, headers=headers, data=data, timeout=5) +# flagged = ret.json()["results"][0]["flagged"] +# except requests.exceptions.RequestException as e: +# flagged = False +# except KeyError as e: +# flagged = False + +# return flagged + + +# def pretty_print_semaphore(semaphore): +# if semaphore is None: +# return "None" +# return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" + + +class KeywordsStoppingCriteria(StoppingCriteria): + def __init__(self, keywords, tokenizer, input_ids): + self.keywords = keywords + self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords] + self.keyword_ids = [ + keyword_id[0] for keyword_id in self.keyword_ids if type(keyword_id) is list and len(keyword_id) == 1 + ] + self.tokenizer = tokenizer + self.start_len = None + self.input_ids = input_ids + + def __call__(self, output_ids: paddle.Tensor, scores: paddle.Tensor, **kwargs) -> bool: + if self.start_len is None: + self.start_len = self.input_ids.shape[1] + else: + for keyword_id in self.keyword_ids: + if output_ids[0, -1] == keyword_id: + return True + outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len :], skip_special_tokens=True)[0] + for keyword in self.keywords: + if keyword in outputs: + return True + return False + + +def smart_tokenizer_and_embedding_resize(special_tokens_dict, tokenizer, model): + """Resize tokenizer and embedding. + + Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 
+ """ + # num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + # # num_new_tokens = 1 + # # tokenizer.add_tokens(special_tokens_dict, special_tokens=True) + # model.resize_token_embeddings(len(tokenizer)) + + num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + model.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = model.get_input_embeddings().weight # .data + output_embeddings = model.get_output_embeddings().weight # .data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + +# def maybe_zero_3(param, ignore_status=False, name=None): +# from deepspeed import zero +# from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus +# if hasattr(param, "ds_id"): +# if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: +# if not ignore_status: +# logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}") +# with zero.GatheredParameters([param]): +# param = param.data.detach().cpu().clone() +# else: +# param = param.detach().cpu().clone() +# return param + + +# # Borrowed from peft.utils.get_peft_model_state_dict +# def get_peft_state_maybe_zero_3(named_params, bias): +# if bias == "none": +# to_return = {k: t for k, t in named_params if "lora_" in k} +# elif bias == "all": +# to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} +# elif bias == "lora_only": +# to_return = {} +# maybe_lora_bias = {} +# lora_bias_names = set() +# for k, t in named_params: +# if "lora_" in k: +# to_return[k] = t +# bias_name = k.split("lora_")[0] + "bias" +# lora_bias_names.add(bias_name) +# elif "bias" in k: +# maybe_lora_bias[k] = t +# for k, t in maybe_lora_bias: +# if bias_name in lora_bias_names: +# to_return[bias_name] = t +# else: +# raise NotImplementedError +# to_return = {k: maybe_zero_3(v, name=k) for k, v in to_return.items()} +# return to_return + + +# def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True): +# to_return = {k: t for k, t in named_params if "lora_" not in k} +# if require_grad_only: +# to_return = {k: t for k, t in to_return.items() if t.requires_grad} +# to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} +# return to_return + + +# def find_all_linear_names(model): +# cls = torch.nn.Linear +# lora_module_names = set() +# for name, module in model.named_modules(): +# if isinstance(module, cls) and 'vision_model' not in name and 'mm_projector' not in name and 'vision_encoder' not in name and 'conv_final' not in name and'lm_head' not in name: +# lora_module_names.add(name) + +# print(lora_module_names) +# return list(lora_module_names) From f87d86b863c308e8bb20064df19f102ba7e98913 Mon Sep 17 00:00:00 2001 From: MqLeet Date: Mon, 16 Dec 2024 22:29:48 +0800 Subject: [PATCH 2/3] add GOT-OCR-2.0 train/inference code --- .../got_dataset.py} | 180 +++++-- paddlemix/demo_images/hospital.jpeg | Bin 0 -> 54871 bytes paddlemix/examples/GOT_OCR_2_0/README.md | 61 ++- .../GOT_OCR_2_0/configs/demo_dataset.json | 4 +- .../examples/GOT_OCR_2_0/got_ocr2_0_infer.py | 19 +- paddlemix/examples/GOT_OCR_2_0/run_train.sh | 78 +++ paddlemix/examples/GOT_OCR_2_0/train_GOT.py | 256 +++++++++ .../models/GOT/{model => }/GOT_ocr_2_0.py | 22 +- 
paddlemix/models/GOT/__init__.py | 2 + paddlemix/models/GOT/data/__init__.py | 121 ----- paddlemix/models/GOT/data/base_dataset.py | 82 --- .../vision_encoder => }/got_vision_b.py | 32 +- paddlemix/models/GOT/model/__init__.py | 15 - .../GOT/model/vision_encoder/__init__.py | 13 - .../models/GOT/model/vision_encoder/vary_b.py | 487 ------------------ paddlemix/models/GOT/utils/conversation.py | 65 +-- paddlemix/models/GOT/utils/utils.py | 126 +---- paddlemix/processors/__init__.py | 3 +- .../got_process.py} | 9 - 19 files changed, 547 insertions(+), 1028 deletions(-) rename paddlemix/{models/GOT/data/conversation_dataset_qwen.py => datasets/got_dataset.py} (68%) create mode 100644 paddlemix/demo_images/hospital.jpeg create mode 100644 paddlemix/examples/GOT_OCR_2_0/run_train.sh create mode 100644 paddlemix/examples/GOT_OCR_2_0/train_GOT.py rename paddlemix/models/GOT/{model => }/GOT_ocr_2_0.py (98%) delete mode 100644 paddlemix/models/GOT/data/__init__.py delete mode 100644 paddlemix/models/GOT/data/base_dataset.py rename paddlemix/models/GOT/{model/vision_encoder => }/got_vision_b.py (94%) delete mode 100644 paddlemix/models/GOT/model/__init__.py delete mode 100644 paddlemix/models/GOT/model/vision_encoder/__init__.py delete mode 100644 paddlemix/models/GOT/model/vision_encoder/vary_b.py rename paddlemix/{models/GOT/model/plug/blip_process.py => processors/got_process.py} (98%) diff --git a/paddlemix/models/GOT/data/conversation_dataset_qwen.py b/paddlemix/datasets/got_dataset.py similarity index 68% rename from paddlemix/models/GOT/data/conversation_dataset_qwen.py rename to paddlemix/datasets/got_dataset.py index fd4240374..f261d3b97 100644 --- a/paddlemix/models/GOT/data/conversation_dataset_qwen.py +++ b/paddlemix/datasets/got_dataset.py @@ -13,31 +13,25 @@ # limitations under the License. 
import copy - -# import io import json import logging - -# import os import random - -# from typing import Dict, List, Optional, Sequence, Tuple, Union from typing import Dict - import paddle +from paddle import Tensor +import paddlenlp from PIL import Image, ImageFile - ImageFile.LOAD_TRUNCATED_IMAGES = True - -from megfile import smart_glob -from natsort import natsorted - -# from ..utils.constants import CONVERSATION_DATA -from ..utils.conversation import ( # conv_templates,; default_conversation, +from ..models.GOT.utils.conversation import ( SeparatorStyle, conv_mpt, ) -from .base_dataset import BaseDataset +from dataclasses import dataclass +from functools import partial +from typing import List, Union +from megfile import smart_glob +from natsort import natsorted + IGNORE_INDEX = -100 CONTROLLER_HEART_BEAT_EXPIRATION = 30 @@ -61,6 +55,30 @@ DEFAULT_IM_END_TOKEN = "" +class BaseDataset(paddle.io.Dataset): + def __init__(self, datasets: str, tokenizer: paddlenlp.transformers.PretrainedTokenizer, multimodal_cfg: dict): + super(BaseDataset, self).__init__() + self.tokenizer = tokenizer + self.multimodal_cfg = multimodal_cfg + + logging.warning(f"Using {multimodal_cfg['image_token_len']} tokens for representing image") + + def image_processor(self, image): + # processor = self.multimodal_cfg['image_processor'] # the first processor, usually is the clip pretrained model (vit) + processor_high = self.multimodal_cfg[ + "image_processor_high" + ] # the second processor, usually is the designed image encoder (sam/swin/cnn) + image_high = image.copy() + image_high = processor_high(image_high) + return image_high + + def __len__(self): + return len(self.list_data_dict) + + def __getitem__(self, i) -> Dict[str, paddle.Tensor]: + pass + + class ConversationDataset(BaseDataset): """Conversation format dataset stage2 fine-tuning.""" @@ -83,29 +101,30 @@ def __init__(self, meta_path, tokenizer, multimodal_cfg): # for name_all in datasets.split("+"): # for name in got_data_dict[name_all]: ds_collections = json.loads(open(meta_path).read()) + #ds_collections = json.load(open(meta_path, 'r')) for ds_idx, ds_name in enumerate(ds_collections.keys()): - if 1: - # dataset = CONVERSATION_DATA[ds_name] - dataset = ds_collections[ds_name] - - data_path = dataset["annotations"] - if data_path.endswith(".json"): - data = json.load(open(data_path, "r")) - elif data_path.endswith(".jsonl"): - with open(data_path, "r") as f: - data = f.readlines() - for ii in range(len(data)): - data[ii] = json.loads(data[ii]) - else: - raise ValueError(f"Unknown file extension: {data_path}") + # dataset = CONVERSATION_DATA[ds_name] + dataset = ds_collections[ds_name] + + data_path = dataset["annotations"] + #image_root = dataset["images"] + if data_path.endswith(".json"): + data = json.load(open(data_path, "r")) + elif data_path.endswith(".jsonl"): + with open(data_path, "r") as f: + data = f.readlines() + for ii in range(len(data)): + data[ii] = json.loads(data[ii]) + else: + raise ValueError(f"Unknown file extension: {data_path}") - list_data_dict.extend(data) + list_data_dict.extend(data) - image_path = dataset["images"] # image_root + image_path = dataset["images"] # image_root - list_image_path.extend([image_path] * len(data)) + list_image_path.extend([image_path] * len(data)) - logging.warning(f"Data from {data_path} provide {len(data)} conversations.") + logging.warning(f"Data from {data_path} provide {len(data)} conversations.") assert len(list_data_dict) == len(list_image_path) logging.warning(f"{len(list_data_dict)} 
conversations in total.") @@ -116,9 +135,7 @@ def __init__(self, meta_path, tokenizer, multimodal_cfg): self.list_image_path = list_image_path_new self.im_patch_token = 151859 - self.im_start_token = 151857 - self.im_end_token = 151858 def multimodal_processor(self, sources, flag_num_patches): @@ -327,3 +344,96 @@ def __getitem__(self, i) -> Dict[str, paddle.Tensor]: data_dict["image"] = [paddle.zeros([3, 1024, 1024])] data_dict["image_high"] = [paddle.zeros([3, 1024, 1024])] return data_dict + + +# helpers +def pad_sequence_paddle(sequences, padding_value=0): + """ + Implement a function similar to PyTorch's pad_sequence in PaddlePaddle. + + Args: + - sequences (list of Tensor): The list of sequences to be padded. + - padding_value (float, optional): The value used for padding, default is 0. + + Returns: + - Tensor: The result of padding all sequences to the same length. + """ + # Calculate the maximum length + max_len = max([seq.shape[0] for seq in sequences]) + + # Pad sequences + padded_sequences = [] + for seq in sequences: + # Calculate the length to pad + padding_len = max_len - seq.shape[0] + + # Create a padding tensor + if padding_len > 0: + padding_tensor = paddle.full([padding_len] + list(seq.shape[1:]), padding_value, dtype=seq.dtype) + # Concatenate the original sequence and the padding tensor + padded_seq = paddle.concat([seq, padding_tensor], axis=0) + else: + padded_seq = seq + + padded_sequences.append(padded_seq) + + # Stack the padded sequences to form a batch + padded_batch = paddle.stack(padded_sequences, axis=0) + return padded_batch + + +def orig_pad_sequence( + sequences: Union[Tensor, List[Tensor]], + batch_first: bool = False, + padding_value: float = 0.0, +) -> Tensor: + if batch_first: + return pad_sequence_paddle(sequences, padding_value) + else: + assert False, "Not implemented" + + +@dataclass +class DataCollatorForSupervisedDataset(object): + tokenizer: paddlenlp.transformers.PretrainedTokenizer + + def __call__(self, instances): + input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) + images = [paddle.stack(instance["image"]) for instance in instances] + images_high = [paddle.stack(instance["image_high"]) for instance in instances] + images = list(zip(images, images_high)) + + pad_sequence = partial(orig_pad_sequence, batch_first=True) + + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id) + + labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) + + batch = dict( + input_ids=input_ids, + labels=labels, + attention_mask=input_ids.not_equal(paddle.to_tensor(self.tokenizer.pad_token_id)), + images=images, + ) + return batch + + +def make_supervised_data_module(interleave, with_box, tokenizer, data_args): + assert data_args.conversation_version == "mpt" + + train_dataset = ConversationDataset( + tokenizer=tokenizer, + # datasets=data_args.datasets, + meta_path=data_args.meta_path, + multimodal_cfg=dict( + sep_image_conv_front=data_args.sep_image_conv_front, + image_token_len=data_args.image_token_len, + image_aspect_ratio=data_args.image_aspect_ratio, + use_im_start_end=data_args.use_im_start_end, + image_processor=data_args.image_processor, + image_processor_high=data_args.image_processor_high, + box_limit=data_args.box_limit, + ), + ) + data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) + return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator) diff --git 
a/paddlemix/demo_images/hospital.jpeg b/paddlemix/demo_images/hospital.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..22cbdfeeaff95cbbdeaf52f5efcdc2c5c0fd715b
GIT binary patch
literal 54871
[binary JPEG image data omitted]
diff --git a/paddlemix/examples/GOT_OCR_2_0/README.md b/paddlemix/examples/GOT_OCR_2_0/README.md
--- a/paddlemix/examples/GOT_OCR_2_0/README.md
+++ b/paddlemix/examples/GOT_OCR_2_0/README.md
 - **python >= 3.10**
-- **paddlepaddle-gpu 要求版本develop**
+- **paddlepaddle-gpu 要求 3.0.0b2 或 develop 版本**
 ```
-# 安装示例
+# develop版安装示例
 python -m pip install paddlepaddle-gpu==0.0.0.post118 -f https://www.paddlepaddle.org.cn/whl/linux/gpu/develop.html
 ```

-- paddlenlp >= 3.0.0(默认开启flash_attn,推荐源码编译安装)
+- **paddlenlp == 3.0.0b2**
+
+> 注:默认开启 flash_attn,使用 flash_attn 要求 A100/A800 或 H20 显卡;V100 请使用 float16 推理。

-> 注:
-* 请确保安装了以上依赖,否则无法运行。同时,需要安装 paddlemix/external_ops 下的自定义OP, `python setup.py install`。如果安装后仍然找不到算子,需要额外设置PYTHONPATH

 ## 3 推理预测

-1. plain texts OCR:
+### 3.1. plain texts OCR:
 ```bash
-python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py --model_name_or_path /GOT_weights/ --image_file /an/image/file.png --ocr_type ocr
+python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py \
+    --model_name_or_path stepfun-ai/GOT-OCR2_0 \
+    --image_file paddlemix/demo_images/hospital.jpeg \
+    --ocr_type ocr \
 ```

-2. format texts OCR:
+### 3.2. format texts OCR:
 ```bash
-python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py --model_name_or_path /GOT_weights/ --image_file /an/image/file.png --ocr_type format
+python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py \
+    --model_name_or_path stepfun-ai/GOT-OCR2_0 \
+    --image_file paddlemix/demo_images/hospital.jpeg \
+    --ocr_type format \
 ```

-3. fine-grained OCR:
+### 3.3. 
fine-grained OCR: ```bash -python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py --model_name_or_path /GOT_weights/ --image_file /an/image/file.png --ocr_type format/ocr --box [x1,y1,x2,y2] +python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py \ + --model_name_or_path stepfun-ai/GOT-OCR2_0 \ + --image_file paddlemix/demo_images/hospital.jpeg \ + --ocr_type ocr \ + --box [x1,y1,x2,y2] \ ``` + ```bash -python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py --model_name_or_path /GOT_weights/ --image_file /an/image/file.png --ocr_type format/ocr --color red/green/blue +python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py \ + --model_name_or_path stepfun-ai/GOT-OCR2_0 \ + --image_file paddlemix/demo_images/hospital.jpeg \ + --ocr_type ocr \ + --color red \ ``` -4. multi-crop OCR: +### 3.4. multi-crop OCR: ```bash -python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py --model_name_or_path /GOT_weights/ --image_file /an/image/file.png --multi_crop --ocr_type format/ocr +python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py \ + --model_name_or_path stepfun-ai/GOT-OCR2_0 \ + --image_file paddlemix/demo_images/hospital.jpeg \ + --multi_crop \ + --ocr_type ocr \ ``` -4. render the formatted OCR results: ```bash -python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py --model_name_or_path /GOT_weights/ --image_file /an/image/file.png --ocr_type format --render +# render the formatted OCR results: +python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py \ + --model_name_or_path stepfun-ai/GOT-OCR2_0 \ + --image_file paddlemix/demo_images/hospital.jpeg \ + --multi_crop \ + --ocr_type ocr \ + --render \ ``` + ## 参考文献 ```BibTeX @article{wei2024general, diff --git a/paddlemix/examples/GOT_OCR_2_0/configs/demo_dataset.json b/paddlemix/examples/GOT_OCR_2_0/configs/demo_dataset.json index e728195a8..3fe8acb5e 100644 --- a/paddlemix/examples/GOT_OCR_2_0/configs/demo_dataset.json +++ b/paddlemix/examples/GOT_OCR_2_0/configs/demo_dataset.json @@ -1,6 +1,6 @@ { "synthdog_en": { - "images": "playground/data/synthdog-en/", - "annotations": "playground/opensource/synthdog_en.jsonl" + "images": "synthdog_en/", + "annotations": "synthdog_en/synthdog_en_29765_ocr_1k.json" } } diff --git a/paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py b/paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py index cba95bb7c..d71f5eac1 100644 --- a/paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py +++ b/paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py @@ -13,38 +13,25 @@ # limitations under the License. 
import argparse - import paddle from paddlenlp.transformers import QWenTokenizer - -from paddlemix.models.GOT.model import GOTQwenForCausalLM +from paddlemix.models.GOT.GOT_ocr_2_0 import GOTQwenForCausalLM parser = argparse.ArgumentParser() - -parser.add_argument("--model_name_or_path", type=str, default="GOT-OCR2_0_pd", help="pretrained ckpt and tokenizer") -parser.add_argument("--image_file", type=str, default="yiyuan.jpeg") +parser.add_argument("--model_name_or_path", type=str, default="stepfun-ai/GOT-OCR2_0", help="pretrained ckpt and tokenizer") +parser.add_argument("--image_file", type=str, default="paddlemix/demo_images/hospital.jpeg") parser.add_argument("--multi_crop", action="store_true") parser.add_argument("--ocr_type", type=str, default="plain", choices=["ocr", "format"]) parser.add_argument("--box", type=str, default="") parser.add_argument("--color", type=str, default="") parser.add_argument("--render", action="store_true") - args = parser.parse_args() model_name_or_path = args.model_name_or_path tokenizer = QWenTokenizer.from_pretrained(model_name_or_path) -# print('tokenizer:\n', tokenizer) -# print('tokenizer.added_tokens_encoder:\n', tokenizer.added_tokens_encoder) -# print('tokenizer.added_tokens_decoder:\n', tokenizer.added_tokens_decoder) -# PretrainedTokenizer(name_or_path='', -# vocab_size=151851, model_max_len=8000, padding_side='right', -# truncation_side='right', special_tokens={ -# 'pad_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False)}) model = GOTQwenForCausalLM.from_pretrained( model_name_or_path, dtype=paddle.bfloat16, pad_token_id=tokenizer.eos_token_id ).eval() -# print('tokenizer:\n', tokenizer) - # input test image image_file = args.image_file diff --git a/paddlemix/examples/GOT_OCR_2_0/run_train.sh b/paddlemix/examples/GOT_OCR_2_0/run_train.sh new file mode 100644 index 000000000..b1ec2d19e --- /dev/null +++ b/paddlemix/examples/GOT_OCR_2_0/run_train.sh @@ -0,0 +1,78 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -x + +GPUS=${GPUS:-8} +BATCH_SIZE=${BATCH_SIZE:-8} +PER_DEVICE_BATCH_SIZE=${PER_DEVICE_BATCH_SIZE:-1} + +GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / GPUS)) +tensor_parallel_degree=${tensor_parallel_degree:-1} +sharding_parallel_degree=$((GPUS / tensor_parallel_degree)) + +export PYTHONPATH="${PYTHONPATH}:$(pwd)" +export MASTER_PORT=34229 +export TF_CPP_MIN_LOG_LEVEL=3 + +OUTPUT_DIR='work_dirs/got_ocr_20' + +# meta='pdf-ocr+scence' + +if [ ! 
-d "$OUTPUT_DIR" ]; then + mkdir -p "$OUTPUT_DIR" +fi + +TRAINING_MODEL_RESUME="None" +TRAINER_INSTANCES='127.0.0.1' +MASTER='127.0.0.1:8080' + +TRAINING_PYTHON="python -m paddle.distributed.launch --master ${MASTER} --nnodes 1 --nproc_per_node ${GPUS} --rank 0 --ips ${TRAINER_INSTANCES} --run_mode=collective" +${TRAINING_PYTHON} --log_dir ${OUTPUT_DIR}/paddle_distributed_logs \ + paddlemix/examples/GOT_OCR_2_0/train_GOT.py \ + --do_train \ + --model_name_or_path "stepfun-ai/GOT-OCR2_0" \ + --output_dir ${OUTPUT_DIR} \ + --logging_dir ${OUTPUT_DIR}/logs \ + --meta_path paddlemix/examples/GOT_OCR_2_0/configs/demo_dataset.json \ + --overwrite_output_dir True \ + --dataloader_num_workers 8 \ + --bf16 True \ + --fp16 False \ + --fp16_opt_level "O2" \ + --num_train_epochs 1 \ + --per_device_train_batch_size ${PER_DEVICE_BATCH_SIZE} \ + --gradient_accumulation_steps ${GRADIENT_ACC} \ + --freeze_vision_tower False \ + --use_im_start_end True \ + --max_seq_length 8192 \ + --recompute False \ + --max_grad_norm 1.0 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 200 \ + --save_total_limit 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.001 \ + --optim "adamw" \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --report_to "visualdl" \ + --tensor_parallel_degree=${tensor_parallel_degree} \ + --sharding_parallel_degree=${sharding_parallel_degree} \ + --pipeline_parallel_degree=1 \ + --sep_parallel_degree=1 \ + --sharding="stage1" \ + 2>&1 | tee -a "${OUTPUT_DIR}/training_log.txt" diff --git a/paddlemix/examples/GOT_OCR_2_0/train_GOT.py b/paddlemix/examples/GOT_OCR_2_0/train_GOT.py new file mode 100644 index 000000000..9fdee3c86 --- /dev/null +++ b/paddlemix/examples/GOT_OCR_2_0/train_GOT.py @@ -0,0 +1,256 @@ +# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: +# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import sys +import paddle.distributed as dist +import paddle +import paddlenlp +from paddlemix.datasets.got_dataset import make_supervised_data_module +from paddlemix.models.GOT.GOT_ocr_2_0 import GOTQwenForCausalLM +from paddlenlp.trainer.trainer_utils import get_last_checkpoint + +from paddlemix.models.GOT.utils.utils import smart_tokenizer_and_embedding_resize +from paddlenlp.trainer import PdArgumentParser, TrainingArguments, set_seed +from paddlenlp.trainer.trainer import Trainer +from dataclasses import dataclass, field +from typing import Dict, Optional +from paddlenlp.transformers import QWenTokenizer +import logging +logger = logging.getLogger(__name__) + + +def print_trainable_params(model: paddle.nn.Layer) -> None: + trainable_params, all_param = 0, 0 + for k, param in model.named_parameters(): + num_params = param.size + if num_params == 0 and hasattr(param, "ds_numel"): + num_params = param.ds_numel + all_param += num_params + if not param.stop_gradient: + # print('{}, shape: {}, requires grad: {}'.format(k, param.shape, not param.stop_gradient)) + trainable_params += num_params + print( + "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format( + trainable_params, all_param, 100 * trainable_params / all_param + ) + ) + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="stepfun-ai/GOT-OCR2_0") + use_cache: bool = field(default=False) + vision_tower: Optional[str] = field(default="openai/clip-vit-large-patch14") + freeze_vision_tower: bool = field(default=False) + freeze_lm_model: bool = field(default=False) + pretrained_stage1_model: Optional[str] = field(default=None) # mlp &/ vision tower + vision_select_layer: Optional[int] = field(default=-1) # default to the last layer + use_im_start_end: bool = field(default=False) + + +@dataclass +class DataArguments: + datasets: str = field(default=None, metadata={"help": "combinations of the training data."}) + meta_path: Optional[str] = field( + default=None, + metadata={"help": "The path of the meta file of datasets."}, + ) + sep_image_conv_front: bool = False + image_token_len: int = 256 + image_aspect_ratio: str = 'square' + conversation_version: str = 'mpt' + box_limit: int = 0 + max_seq_length: int = 8192 + + +@dataclass +class TrainingArguments(paddlenlp.trainer.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + remove_unused_columns: bool = field(default=False) + force_fsdp: bool = field(default=False) + interleave: bool = field(default=False) + with_box: bool = field(default=False) + model_max_length: int = field( + default=512, + metadata={ + "help": + "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + }, + ) + lora_enable: bool = False + lora_r: int = 8 + lora_alpha: int = 16 + lora_dropout: float = 0.05 + lora_weight_path: str = "" + lora_bias: str = "none" + + +def train(): + # parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) + # model_args, data_args, training_args = parser.parse_args_into_dataclasses() + parser = PdArgumentParser((ModelArguments, DataArguments, TrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script, and it's the path to a json file, + # let's parse it to get our arguments. 
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + training_args.print_config(model_args, "Model") + training_args.print_config(data_args, "Data") + + # Detecting last checkpoint and eventually continue from last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Load model + if training_args.fp16_opt_level == "O2": + if training_args.fp16: + dtype = "float16" + elif training_args.bf16 and paddle.amp.is_bfloat16_supported(): + dtype = "bfloat16" + else: + raise ValueError("Please specific dtype: --fp16 or --bf16") + else: + dtype = "float32" + + # Set seed before initializing model. + set_seed(training_args.seed) + + # Load pretrained model, tokenizer, and image processor + tokenizer_path = model_args.model_name_or_path + print(f"Loading Tokenizer: {tokenizer_path}") + + tokenizer = QWenTokenizer.from_pretrained( + model_args.model_name_or_path, + padding_side="right", + model_max_length=training_args.model_max_length) + print("tokenizer", tokenizer) + print("len(tokenizer)", len(tokenizer)) + print("tokenizer.added_tokens_encoder", tokenizer.added_tokens_encoder) + print("tokenizer.added_tokens_decoder", tokenizer.added_tokens_decoder) + + model = GOTQwenForCausalLM.from_pretrained( + model_args.model_name_or_path, dtype=dtype) + + smart_tokenizer_and_embedding_resize( + special_tokens_dict=dict(pad_token='<|endoftext|>'), + tokenizer=tokenizer, + model=model, + ) + + vision_tower_dict = model.get_model().initialize_vision_modules( + vision_tower=model_args.vision_tower, + pretrained_stage1_model=model_args.pretrained_stage1_model, + freeze_vision_tower=model_args.freeze_vision_tower, + use_im_start_end=model_args.use_im_start_end, + vision_select_layer=model_args.vision_select_layer, + dtype=dtype, + ) + + model.initialize_vision_tokenizer( + tokenizer=tokenizer, + freeze_lm_model=model_args.freeze_lm_model, + pretrained_stage1_model=model_args.pretrained_stage1_model, + ) + + # 'image_processor_high + # data_args.image_token_len = vision_tower_dict['image_token_len'] + data_args.image_token_len = 256 + data_args.image_processor = vision_tower_dict['image_processor'] + data_args.image_processor_high = vision_tower_dict['image_processor_high'] + data_args.use_im_start_end = model_args.use_im_start_end + + def _freeze_params(module): + for param in module.parameters(): + param.stop_gradient = not False + + # mixed relation, to be fixed + if model_args.freeze_lm_model: + _freeze_params(model.get_model().mm_projector) + _freeze_params(model.get_model().mm_projector_vary) + _freeze_params(model.get_input_embeddings()) + + if model_args.freeze_vision_tower: + _freeze_params(model.qwen2.vision_tower_high) + + # params_grad = [p.numel() for n, p in model.named_parameters() if 
p.requires_grad] + # print(f"Number of Mapping Trainable Parameters: {sum(params_grad) / (1 << 20):.2f} M") + print_trainable_params(model) + # trainable params: 464959488 || all params: 560528640 || trainable%: 82.9502 + + params_grad = [p.numel() for n, p in model.named_parameters() if not p.stop_gradient] + print(f"Number of Mapping Trainable Parameters: {sum(params_grad) / (1 << 20):.2f} M") + + # print trainable parameters + if dist.get_rank() == 0: + for name, param in model.named_parameters(): + if not param.stop_gradient: + logger.info(name) + + # set seed for paddle dataloaders + set_seed(training_args.seed) + + data_module = make_supervised_data_module( + interleave=training_args.interleave, + with_box=training_args.with_box, + tokenizer=tokenizer, + data_args=data_args + ) + + #trainer = GOTTrainer( + trainer = Trainer( + model=model, + args=training_args, + tokenizer=tokenizer, + **data_module, + ) + + # Training + if training_args.do_train: + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + train_result = trainer.train(resume_from_checkpoint=checkpoint) + trainer.save_model() # Saves the tokenizer too for easy upload + + metrics = train_result.metrics + try: + metrics["train_samples"] = len(data_module["train_dataset"]) + except: + metrics["train_samples"] = -1 + + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + +if __name__ == "__main__": + train() diff --git a/paddlemix/models/GOT/model/GOT_ocr_2_0.py b/paddlemix/models/GOT/GOT_ocr_2_0.py similarity index 98% rename from paddlemix/models/GOT/model/GOT_ocr_2_0.py rename to paddlemix/models/GOT/GOT_ocr_2_0.py index 4bfb2cad2..66ae85824 100644 --- a/paddlemix/models/GOT/model/GOT_ocr_2_0.py +++ b/paddlemix/models/GOT/GOT_ocr_2_0.py @@ -13,15 +13,12 @@ # limitations under the License. 
from io import BytesIO -from typing import List, Optional # , Tuple, Union +from typing import List, Optional import paddle import paddle.nn as nn - -# import paddle.nn.functional as F -# import paddlenlp import requests -from paddlenlp.generation.stopping_criteria import ( # , TextStreamer; StoppingCriteria, +from paddlenlp.generation.stopping_criteria import ( StoppingCriteriaList, ) from paddlenlp.transformers import Qwen2Config, Qwen2ForCausalLM, Qwen2Model @@ -34,11 +31,10 @@ DEFAULT_IM_END_TOKEN = "" import dataclasses - +from enum import Enum, auto from paddle.vision import transforms - -from .plug.blip_process import BlipImageEvalProcessor -from .vision_encoder.got_vision_b import build_GOT_vit_b +from ...processors.got_process import BlipImageEvalProcessor +from .got_vision_b import build_GOT_vit_b class Qwen2LMHead(nn.Layer): @@ -69,9 +65,6 @@ def forward(self, hidden_states, tensor_parallel_output=1): return logits -from enum import Enum, auto - - class SeparatorStyle(Enum): """Different separator style.""" @@ -423,13 +416,14 @@ def forward( shift_logits = logits[..., :-1, :] shift_labels = labels[..., 1:] # Flatten the tokens - # loss_fct = nn.CrossEntropyLoss() + #loss_fct = nn.CrossEntropyLoss() loss_fct = nn.CrossEntropyLoss(reduction="sum") shift_logits = shift_logits.reshape([-1, self.config.vocab_size]) shift_labels = shift_labels.reshape([-1]) # Enable model parallelism + loss = loss_fct(shift_logits, shift_labels) - label_sum = paddle.sum(shift_labels != -100).cast("float32") + label_sum = paddle.sum(shift_labels != -100) #.cast("float32") loss = loss / label_sum if not return_dict: diff --git a/paddlemix/models/GOT/__init__.py b/paddlemix/models/GOT/__init__.py index fd05a9208..d6794b65d 100644 --- a/paddlemix/models/GOT/__init__.py +++ b/paddlemix/models/GOT/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .GOT_ocr_2_0 import GOTConfig, GOTQwenForCausalLM, GOTQwenModel diff --git a/paddlemix/models/GOT/data/__init__.py b/paddlemix/models/GOT/data/__init__.py deleted file mode 100644 index 061cac8b2..000000000 --- a/paddlemix/models/GOT/data/__init__.py +++ /dev/null @@ -1,121 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass -from functools import partial -from sys import meta_path -from typing import List, Union - -import paddle -import paddlenlp -from paddle import Tensor - -from paddlemix.models.GOT.data.conversation_dataset_qwen import ConversationDataset - -from ..utils.constants import * - -IGNORE_INDEX = -100 - - -# helpers -def pad_sequence_paddle(sequences, padding_value=0): - """ - Implement a function similar to PyTorch's pad_sequence in PaddlePaddle. - - Args: - - sequences (list of Tensor): The list of sequences to be padded. 
- - padding_value (float, optional): The value used for padding, default is 0. - - Returns: - - Tensor: The result of padding all sequences to the same length. - """ - # Calculate the maximum length - max_len = max([seq.shape[0] for seq in sequences]) - - # Pad sequences - padded_sequences = [] - for seq in sequences: - # Calculate the length to pad - padding_len = max_len - seq.shape[0] - - # Create a padding tensor - if padding_len > 0: - padding_tensor = paddle.full([padding_len] + list(seq.shape[1:]), padding_value, dtype=seq.dtype) - # Concatenate the original sequence and the padding tensor - padded_seq = paddle.concat([seq, padding_tensor], axis=0) - else: - padded_seq = seq - - padded_sequences.append(padded_seq) - - # Stack the padded sequences to form a batch - padded_batch = paddle.stack(padded_sequences, axis=0) - return padded_batch - - -def orig_pad_sequence( - sequences: Union[Tensor, List[Tensor]], - batch_first: bool = False, - padding_value: float = 0.0, -) -> Tensor: - if batch_first: - return pad_sequence_paddle(sequences, padding_value) - else: - assert False, "Not implemented" - - -@dataclass -class DataCollatorForSupervisedDataset(object): - tokenizer: paddlenlp.transformers.PretrainedTokenizer - - def __call__(self, instances): - input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) - images = [paddle.stack(instance["image"]) for instance in instances] - images_high = [paddle.stack(instance["image_high"]) for instance in instances] - images = list(zip(images, images_high)) - - pad_sequence = partial(orig_pad_sequence, batch_first=True) - - input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id) - - labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) - - batch = dict( - input_ids=input_ids, - labels=labels, - attention_mask=input_ids.not_equal(paddle.to_tensor(self.tokenizer.pad_token_id)), - images=images, - ) - return batch - - -def make_supervised_data_module(interleave, with_box, tokenizer, data_args): - assert data_args.conversation_version == "mpt" - - train_dataset = ConversationDataset( - tokenizer=tokenizer, - # datasets=data_args.datasets, - meta_path=data_args.meta_path, - multimodal_cfg=dict( - sep_image_conv_front=data_args.sep_image_conv_front, - image_token_len=data_args.image_token_len, - image_aspect_ratio=data_args.image_aspect_ratio, - use_im_start_end=data_args.use_im_start_end, - image_processor=data_args.image_processor, - image_processor_high=data_args.image_processor_high, - box_limit=data_args.box_limit, - ), - ) - data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) - return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator) diff --git a/paddlemix/models/GOT/data/base_dataset.py b/paddlemix/models/GOT/data/base_dataset.py deleted file mode 100644 index c8e2144c6..000000000 --- a/paddlemix/models/GOT/data/base_dataset.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -# import copy -# import io -# import json -import logging - -# from typing import Dict, List, Optional, Sequence, Tuple, Union -from typing import Dict - -import paddle -import paddlenlp -from paddle.io import Dataset -from PIL import ImageFile # , Image - -ImageFile.LOAD_TRUNCATED_IMAGES = True -# from ..utils.constants import * - - -class BaseDataset(Dataset): - def __init__(self, datasets: str, tokenizer: paddlenlp.transformers.PretrainedTokenizer, multimodal_cfg: dict): - super(BaseDataset, self).__init__() - self.tokenizer = tokenizer - self.multimodal_cfg = multimodal_cfg - - logging.warning(f"Using {multimodal_cfg['image_token_len']} tokens for representing image") - - def image_processor(self, image): - # processor = self.multimodal_cfg['image_processor'] # the first processor, usually is the clip pretrained model (vit) - processor_high = self.multimodal_cfg[ - "image_processor_high" - ] # the second processor, usually is the designed image encoder (sam/swin/cnn) - image_high = image.copy() - - # Vary old codes - - # # TODO the 'keep', 'padding' only used for the first processor - # if self.multimodal_cfg['image_aspect_ratio'] == 'keep': - # max_hw, min_hw = max(image.size), min(image.size) - # aspect_ratio = max_hw / min_hw - # max_len, min_len = 448, 224 - # shortest_edge = int(min(max_len / aspect_ratio, min_len)) - # image = processor.preprocess(image, return_tensors='pt', do_center_crop=False, size={"shortest_edge": shortest_edge})['pixel_values'][0] - # elif self.multimodal_cfg['image_aspect_ratio'] == 'pad': - # def expand2square(pil_img, background_color): - # width, height = pil_img.size - # if width == height: - # return pil_img - # elif width > height: - # result = Image.new(pil_img.mode, (width, width), background_color) - # result.paste(pil_img) # for simpler box processing - # return result - # else: - # result = Image.new(pil_img.mode, (height, height), background_color) - # result.paste(pil_img) # for simpler box processing - # return result - # image = expand2square(image, tuple(int(x*255) for x in processor.image_mean)) - # image = processor.preprocess(image, return_tensors='pt', do_center_crop=False, size={"shortest_edge": 224})['pixel_values'][0] - # else: - # image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] - - image_high = processor_high(image_high) - - return image_high - - def __len__(self): - return len(self.list_data_dict) - - def __getitem__(self, i) -> Dict[str, paddle.Tensor]: - pass diff --git a/paddlemix/models/GOT/model/vision_encoder/got_vision_b.py b/paddlemix/models/GOT/got_vision_b.py similarity index 94% rename from paddlemix/models/GOT/model/vision_encoder/got_vision_b.py rename to paddlemix/models/GOT/got_vision_b.py index 38e25f495..a005a7f48 100644 --- a/paddlemix/models/GOT/model/vision_encoder/got_vision_b.py +++ b/paddlemix/models/GOT/got_vision_b.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-# import math from functools import partial from typing import Optional, Tuple, Type @@ -20,31 +19,6 @@ import paddle.nn as nn import paddle.nn.functional as F -# class Projector(paddle.nn.Layer): - -# def __init__( -# self, -# width: 256, -# n_queries: int = 256, -# output_dim: int = 4096, -# **kwargs -# ): -# super().__init__() - -# norm_layer = partial(paddle.nn.LayerNorm, epsilon=1e-06) -# self.attn_pool = Resampler(grid_size=int(math.sqrt(n_queries)), -# embed_dim=output_dim, num_heads=output_dim // 128, kv_dim=width, -# norm_layer=norm_layer) -# self.ln_post = norm_layer(output_dim) -# self.proj = paddle.base.framework.EagerParamBase.from_tensor(tensor -# =output_dim ** -0.5 * paddle.randn(shape=[output_dim, output_dim])) - -# def forward(self, x: paddle.Tensor): -# x = self.attn_pool(x) -# x = self.ln_post(x) -# x = x @ self.proj -# return x - class MLPBlock(paddle.nn.Layer): def __init__(self, embedding_dim: int, mlp_dim: int, act: Type[paddle.nn.Layer] = paddle.nn.GELU) -> None: @@ -451,13 +425,12 @@ def forward(self, x: paddle.Tensor) -> paddle.Tensor: return x -def build_GOT_vit_b(checkpoint=None): +def build_GOT_vit_b(): return _build_GOT_vision( encoder_embed_dim=768, encoder_depth=12, encoder_num_heads=12, encoder_global_attn_indexes=[2, 5, 8, 11], - checkpoint=checkpoint, ) @@ -466,7 +439,6 @@ def _build_GOT_vision( encoder_depth, encoder_num_heads, encoder_global_attn_indexes, - checkpoint=None, ): prompt_embed_dim = 256 image_size = 1024 @@ -487,4 +459,4 @@ def _build_GOT_vision( out_chans=prompt_embed_dim, ) - return image_encoder + return image_encoder \ No newline at end of file diff --git a/paddlemix/models/GOT/model/__init__.py b/paddlemix/models/GOT/model/__init__.py deleted file mode 100644 index d6794b65d..000000000 --- a/paddlemix/models/GOT/model/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .GOT_ocr_2_0 import GOTConfig, GOTQwenForCausalLM, GOTQwenModel diff --git a/paddlemix/models/GOT/model/vision_encoder/__init__.py b/paddlemix/models/GOT/model/vision_encoder/__init__.py deleted file mode 100644 index fd05a9208..000000000 --- a/paddlemix/models/GOT/model/vision_encoder/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
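Taken together, the moves above flatten the GOT package: `GOTQwenForCausalLM` is now re-exported from `paddlemix/models/GOT/__init__.py`, the SAM-style ViT-B encoder lives in `paddlemix/models/GOT/got_vision_b.py`, and `build_GOT_vit_b()` no longer accepts a `checkpoint` argument (weights are expected to come from the full GOT checkpoint loaded with `from_pretrained`). A minimal sketch of the new import paths, for illustration only (the encoder is normally built inside the GOT model rather than by user code):

```python
import paddle

# New flattened import paths introduced by this patch
from paddlemix.models.GOT import GOTQwenForCausalLM  # re-exported in GOT/__init__.py
from paddlemix.models.GOT.got_vision_b import build_GOT_vit_b

# The ViT-B/16 encoder is now built without a checkpoint argument and expects
# 1024x1024 inputs (random weights here; real weights ship with the GOT checkpoint).
encoder = build_GOT_vit_b()
features = encoder(paddle.randn([1, 3, 1024, 1024]))
print(features.shape)  # [1, 1024, 16, 16]: 64x64 patch grid, then two stride-2 convs
```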
diff --git a/paddlemix/models/GOT/model/vision_encoder/vary_b.py b/paddlemix/models/GOT/model/vision_encoder/vary_b.py deleted file mode 100644 index 213571535..000000000 --- a/paddlemix/models/GOT/model/vision_encoder/vary_b.py +++ /dev/null @@ -1,487 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# import math -from functools import partial -from typing import Optional, Tuple, Type - -import paddle -import paddle.nn as nn -import paddle.nn.functional as F - -# class Projector(paddle.nn.Layer): - -# def __init__( -# self, -# width: 256, -# n_queries: int = 256, -# output_dim: int = 4096, -# **kwargs -# ): -# super().__init__() - -# norm_layer = partial(paddle.nn.LayerNorm, epsilon=1e-06) -# self.attn_pool = Resampler(grid_size=int(math.sqrt(n_queries)), -# embed_dim=output_dim, num_heads=output_dim // 128, kv_dim=width, -# norm_layer=norm_layer) -# self.ln_post = norm_layer(output_dim) -# self.proj = paddle.base.framework.EagerParamBase.from_tensor(tensor -# =output_dim ** -0.5 * paddle.randn(shape=[output_dim, output_dim])) - -# def forward(self, x: paddle.Tensor): -# x = self.attn_pool(x) -# x = self.ln_post(x) -# x = x @ self.proj -# return x - - -class MLPBlock(paddle.nn.Layer): - def __init__(self, embedding_dim: int, mlp_dim: int, act: Type[paddle.nn.Layer] = paddle.nn.GELU) -> None: - super().__init__() - self.lin1 = nn.Linear(embedding_dim, mlp_dim) - self.lin2 = nn.Linear(mlp_dim, embedding_dim) - self.act = act() - - def forward(self, x: paddle.Tensor) -> paddle.Tensor: - return self.lin2(self.act(self.lin1(x))) - - -class LayerNorm2d(paddle.nn.Layer): - def __init__(self, num_channels: int, epsilon: float = 1e-06) -> None: - super().__init__() - self.weight = paddle.base.framework.EagerParamBase.from_tensor(tensor=paddle.ones(shape=num_channels)) - self.bias = paddle.base.framework.EagerParamBase.from_tensor(tensor=paddle.zeros(shape=num_channels)) - self.epsilon = epsilon - - def forward(self, x: paddle.Tensor) -> paddle.Tensor: - u = x.mean(axis=1, keepdim=True) - s = (x - u).pow(y=2).mean(axis=1, keepdim=True) - x = (x - u) / paddle.sqrt(x=s + self.epsilon) - x = self.weight[:, None, None] * x + self.bias[:, None, None] - return x - - -class ImageEncoderViT(paddle.nn.Layer): - def __init__( - self, - img_size: int = 1024, - patch_size: int = 16, - in_chans: int = 3, - embed_dim: int = 768, - depth: int = 12, - num_heads: int = 12, - mlp_ratio: float = 4.0, - out_chans: int = 256, - qkv_bias: bool = True, - norm_layer: Type[nn.Layer] = nn.LayerNorm, - act_layer: Type[nn.Layer] = nn.GELU, - use_abs_pos: bool = True, - use_rel_pos: bool = False, - rel_pos_zero_init: bool = True, - window_size: int = 0, - global_attn_indexes: Tuple[int, ...] = (), - ) -> None: - """ - Args: - img_size (int): Input image size. - patch_size (int): Patch size. - in_chans (int): Number of input image channels. - embed_dim (int): Patch embedding dimension. - depth (int): Depth of ViT. 
- num_heads (int): Number of attention heads in each ViT block. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool): If True, add a learnable bias to query, key, value. - norm_layer (nn.Layer): Normalization layer. - act_layer (nn.Layer): Activation layer. - use_abs_pos (bool): If True, use absolute positional embeddings. - use_rel_pos (bool): If True, add relative positional embeddings to the attention map. - rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. - window_size (int): Window size for window attention blocks. - global_attn_indexes (list): Indexes for blocks using global attention. - """ - super().__init__() - self.img_size = img_size - - self.patch_embed = PatchEmbed( - kernel_size=(patch_size, patch_size), - stride=(patch_size, patch_size), - in_chans=in_chans, - embed_dim=embed_dim, - ) - - self.pos_embed: Optional[paddle.base.framework.EagerParamBase.from_tensor] = None - if use_abs_pos: - self.pos_embed = paddle.base.framework.EagerParamBase.from_tensor( - tensor=paddle.zeros(shape=[1, img_size // patch_size, img_size // patch_size, embed_dim]) - ) - - self.blocks = paddle.nn.LayerList() - for i in range(depth): - block = Block( - dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - norm_layer=norm_layer, - act_layer=act_layer, - use_rel_pos=use_rel_pos, - rel_pos_zero_init=rel_pos_zero_init, - window_size=window_size if i not in global_attn_indexes else 0, - input_size=(img_size // patch_size, img_size // patch_size), - ) - self.blocks.append(block) - - self.neck = nn.Sequential( - nn.Conv2D( - embed_dim, - out_chans, - kernel_size=1, - bias_attr=False, - ), - LayerNorm2d(out_chans), - nn.Conv2D( - out_chans, - out_chans, - kernel_size=3, - padding=1, - bias_attr=False, - ), - LayerNorm2d(out_chans), - ) - - self.net_2 = nn.Conv2D(256, 512, kernel_size=3, stride=2, padding=1, bias_attr=False) - self.net_3 = nn.Conv2D(512, 1024, kernel_size=3, stride=2, padding=1, bias_attr=False) - - def forward(self, x: paddle.Tensor) -> paddle.Tensor: - x = self.patch_embed(x) - if self.pos_embed is not None: - x = x + self.pos_embed - for blk in self.blocks: - x = blk(x) - x = self.neck(x.transpose([0, 3, 1, 2])) - x = self.net_2(x) - x = self.net_3(x) - return x - - -class Block(paddle.nn.Layer): - """Transformer blocks with support of window attention and residual propagation blocks""" - - def __init__( - self, - dim: int, - num_heads: int, - mlp_ratio: float = 4.0, - qkv_bias: bool = True, - norm_layer: Type[nn.Layer] = nn.LayerNorm, - act_layer: Type[nn.Layer] = nn.GELU, - use_rel_pos: bool = False, - rel_pos_zero_init: bool = True, - window_size: int = 0, - input_size: Optional[Tuple[int, int]] = None, - ) -> None: - """ - Args: - dim (int): Number of input channels. - num_heads (int): Number of attention heads in each ViT block. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool): If True, add a learnable bias to query, key, value. - norm_layer (nn.Layer): Normalization layer. - act_layer (nn.Layer): Activation layer. - use_rel_pos (bool): If True, add relative positional embeddings to the attention map. - rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. - window_size (int): Window size for window attention blocks. If it equals 0, then - use global attention. - input_size (tuple(int, int) or None): Input resolution for calculating the relative - positional parameter size. 
- """ - super().__init__() - self.norm1 = norm_layer(dim) - self.attn = Attention( - dim, - num_heads=num_heads, - qkv_bias=qkv_bias, - use_rel_pos=use_rel_pos, - rel_pos_zero_init=rel_pos_zero_init, - input_size=input_size if window_size == 0 else (window_size, window_size), - ) - - self.norm2 = norm_layer(dim) - self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) - - self.window_size = window_size - - def forward(self, x: paddle.Tensor) -> paddle.Tensor: - shortcut = x - x = self.norm1(x) - # Window partition - if self.window_size > 0: - H, W = x.shape[1], x.shape[2] - x, pad_hw = window_partition(x, self.window_size) - - x = self.attn(x) - # Reverse window partition - if self.window_size > 0: - x = window_unpartition(x, self.window_size, pad_hw, (H, W)) - - x = shortcut + x - x = x + self.mlp(self.norm2(x)) - - return x - - -class Attention(paddle.nn.Layer): - """Multi-head Attention block with relative position embeddings.""" - - def __init__( - self, - dim: int, - num_heads: int = 8, - qkv_bias: bool = True, - use_rel_pos: bool = False, - rel_pos_zero_init: bool = True, - input_size: Optional[Tuple[int, int]] = None, - ) -> None: - """ - Args: - dim (int): Number of input channels. - num_heads (int): Number of attention heads. - qkv_bias (bool): If True, add a learnable bias to query, key, value. - rel_pos (bool): If True, add relative positional embeddings to the attention map. - rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. - input_size (tuple(int, int) or None): Input resolution for calculating the relative - positional parameter size. - """ - super().__init__() - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = head_dim**-0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias) - self.proj = nn.Linear(dim, dim) - - self.use_rel_pos = use_rel_pos - if self.use_rel_pos: - assert input_size is not None, "Input size must be provided if using relative positional encoding." - self.rel_pos_h = paddle.base.framework.EagerParamBase.from_tensor( - tensor=paddle.zeros(shape=[2 * input_size[0] - 1, head_dim]) - ) - self.rel_pos_w = paddle.base.framework.EagerParamBase.from_tensor( - tensor=paddle.zeros(shape=[2 * input_size[1] - 1, head_dim]) - ) - - def forward(self, x: paddle.Tensor) -> paddle.Tensor: - B, H, W, _ = tuple(x.shape) - # qkv with shape (3, B, nHead, H * W, C) - qkv = self.qkv(x).reshape([B, H * W, 3, self.num_heads, -1]).transpose([2, 0, 3, 1, 4]) - # q, k, v with shape (B * nHead, H * W, C) - q, k, v = qkv.reshape([3, B * self.num_heads, H * W, -1]).unbind(axis=0) - - attn = q * self.scale @ k.transpose([0, 1, 3, 2]) # [-2, -1] - - if self.use_rel_pos: - attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) - - attn = F.softmax(attn, axis=-1) - x = (attn @ v).reshape([B, self.num_heads, H, W, -1]).transpose([0, 2, 3, 1, 4]).reshape([B, H, W, -1]) - x = self.proj(x) - - return x - - -def window_partition(x: paddle.Tensor, window_size: int) -> Tuple[paddle.Tensor, Tuple[int, int]]: - """ - Partition into non-overlapping windows with padding if needed. - Args: - x (tensor): input tokens with [B, H, W, C]. - window_size (int): window size. - - Returns: - windows: windows after partition with [B * num_windows, window_size, window_size, C]. 
- (Hp, Wp): padded height and width before partition - """ - B, H, W, C = tuple(x.shape) - - pad_h = (window_size - H % window_size) % window_size - pad_w = (window_size - W % window_size) % window_size - if pad_h > 0 or pad_w > 0: - x = F.pad(x, pad=(0, 0, 0, pad_w, 0, pad_h), pad_from_left_axis=False) - Hp, Wp = H + pad_h, W + pad_w - - x = x.reshape([B, Hp // window_size, window_size, Wp // window_size, window_size, C]) - windows = x.transpose(perm=[0, 1, 3, 2, 4, 5]).reshape([-1, window_size, window_size, C]) - return windows, (Hp, Wp) - - -def window_unpartition( - windows: paddle.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] -) -> paddle.Tensor: - """ - Window unpartition into original sequences and removing padding. - Args: - windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. - window_size (int): window size. - pad_hw (Tuple): padded height and width (Hp, Wp). - hw (Tuple): original height and width (H, W) before padding. - - Returns: - x: unpartitioned sequences with [B, H, W, C]. - """ - Hp, Wp = pad_hw - H, W = hw - B = tuple(windows.shape)[0] // (Hp * Wp // window_size // window_size) - x = windows.reshape([B, Hp // window_size, Wp // window_size, window_size, window_size, -1]) - x = x.transpose([0, 1, 3, 2, 4, 5]).reshape([B, Hp, Wp, -1]) - if Hp > H or Wp > W: - x = x[:, :H, :W, :] - return x - - -def get_rel_pos(q_size: int, k_size: int, rel_pos: paddle.Tensor) -> paddle.Tensor: - """ - Get relative positional embeddings according to the relative positions of - query and key sizes. - Args: - q_size (int): size of query q. - k_size (int): size of key k. - rel_pos (Tensor): relative position embeddings (L, C). - - Returns: - Extracted positional embeddings according to relative positions. - """ - max_rel_dist = int(2 * max(q_size, k_size) - 1) - if tuple(rel_pos.shape)[0] != max_rel_dist: - rel_pos_resized = paddle.nn.functional.interpolate( - rel_pos.reshape([1, tuple(rel_pos.shape)[0], -1]).transpose([0, 2, 1]), - size=max_rel_dist, - mode="linear", - ) - rel_pos_resized = rel_pos_resized.reshape([-1, max_rel_dist]).transpose([1, 0]) - else: - rel_pos_resized = rel_pos - - # Scale the coords with short length if shapes for q and k are different. - q_coords = paddle.arange(end=q_size)[:, None] * max(k_size / q_size, 1.0) - k_coords = paddle.arange(end=k_size)[None, :] * max(q_size / k_size, 1.0) - relative_coords = q_coords - k_coords + (k_size - 1) * max(q_size / k_size, 1.0) - return rel_pos_resized[relative_coords.astype(dtype="int64")] - - -def add_decomposed_rel_pos( - attn: paddle.Tensor, - q: paddle.Tensor, - rel_pos_h: paddle.Tensor, - rel_pos_w: paddle.Tensor, - q_size: Tuple[int, int], - k_size: Tuple[int, int], -) -> paddle.Tensor: - """ - Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. - https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 - Args: - attn (Tensor): attention map. - q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). - rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. - rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. - q_size (Tuple): spatial sequence size of query q with (q_h, q_w). - k_size (Tuple): spatial sequence size of key k with (k_h, k_w). - - Returns: - attn (Tensor): attention map with added relative positional embeddings. 
- """ - q_h, q_w = q_size - k_h, k_w = k_size - Rh = get_rel_pos(q_h, k_h, rel_pos_h) - Rw = get_rel_pos(q_w, k_w, rel_pos_w) - - B, _, dim = tuple(q.shape) - r_q = q.reshape([B, q_h, q_w, dim]) - rel_h = paddle.einsum("bhwc,hkc->bhwk", r_q, Rh) - rel_w = paddle.einsum("bhwc,wkc->bhwk", r_q, Rw) - - attn = (attn.reshape([B, q_h, q_w, k_h, k_w]) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).reshape( - [B, q_h * q_w, k_h * k_w] - ) - - return attn - - -class PatchEmbed(paddle.nn.Layer): - """ - Image to Patch Embedding. - """ - - def __init__( - self, - kernel_size: Tuple[int, int] = (16, 16), - stride: Tuple[int, int] = (16, 16), - padding: Tuple[int, int] = (0, 0), - in_chans: int = 3, - embed_dim: int = 768, - ) -> None: - """ - Args: - kernel_size (Tuple): kernel size of the projection layer. - stride (Tuple): stride of the projection layer. - padding (Tuple): padding size of the projection layer. - in_chans (int): Number of input image channels. - embed_dim (int): Patch embedding dimension. - """ - super().__init__() - self.proj = nn.Conv2D(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) - - def forward(self, x: paddle.Tensor) -> paddle.Tensor: - x = self.proj(x) - # B C H W -> B H W C - x = x.transpose([0, 2, 3, 1]) - return x - - -def build_GOT_vit_b(checkpoint=None): - return _build_GOT_vision( - encoder_embed_dim=768, - encoder_depth=12, - encoder_num_heads=12, - encoder_global_attn_indexes=[2, 5, 8, 11], - checkpoint=checkpoint, - ) - - -def _build_GOT_vision( - encoder_embed_dim, - encoder_depth, - encoder_num_heads, - encoder_global_attn_indexes, - checkpoint=None, -): - prompt_embed_dim = 256 - image_size = 1024 - vit_patch_size = 16 - # image_embedding_size = image_size // vit_patch_size - image_encoder = ImageEncoderViT( - depth=encoder_depth, - embed_dim=encoder_embed_dim, - img_size=image_size, - mlp_ratio=4, - norm_layer=partial(paddle.nn.LayerNorm, epsilon=1e-6), - num_heads=encoder_num_heads, - patch_size=vit_patch_size, - qkv_bias=True, - use_rel_pos=True, - global_attn_indexes=encoder_global_attn_indexes, - window_size=14, - out_chans=prompt_embed_dim, - ) - - return image_encoder diff --git a/paddlemix/models/GOT/utils/conversation.py b/paddlemix/models/GOT/utils/conversation.py index 1bb8a0d6c..54b08b4a5 100644 --- a/paddlemix/models/GOT/utils/conversation.py +++ b/paddlemix/models/GOT/utils/conversation.py @@ -14,7 +14,7 @@ import dataclasses from enum import Enum, auto -from typing import List # , Tuple +from typing import List class SeparatorStyle(Enum): @@ -25,35 +25,6 @@ class SeparatorStyle(Enum): MPT = auto() -# simple_conv_multimodal = Conversation( -# system="You are GOT, a large language and vision assistant trained by Foundation Model Group, Megvii Technology." -# "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." -# "Follow the instructions carefully and explain your answers in detail.", -# # system="", -# roles=("Human", "Assistant"), -# messages=( -# ("Human", "Hi!"), -# ("Assistant", "Hi there! How can I help you today?\n") -# ), -# offset=2, -# sep_style=SeparatorStyle.SINGLE, -# sep="###", -# ) - -# conv_mpt = Conversation( -# system="""<|im_start|>system -# - You are a helpful language and vision assistant. -# - You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. 
-# - You should follow the instructions carefully and explain your answers in detail.""", -# roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), -# version="mpt", -# messages=(), -# offset=0, -# sep_style=SeparatorStyle.MPT, -# sep="<|im_end|>", -# ) - - @dataclasses.dataclass class Conversation: """A class that keeps all conversation history.""" @@ -320,18 +291,6 @@ def dict(self): sep2="", ) -# conv_mpt = Conversation( -# system="""<|im_start|>system -# - You are designed by Megvii(旷视), and your name is GOT. -# - 你叫GOT, 你来自旷视, 你是旷视开发的。 -# - 你擅长分析表格,仔细读图表中的内容,然后给出你的答案。""", -# roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), -# version="mpt", -# messages=(), -# offset=0, -# sep_style=SeparatorStyle.MPT, -# sep="<|im_end|>", -# ) conv_mpt = Conversation( system="""<|im_start|>system @@ -379,20 +338,6 @@ def dict(self): sep2="", ) -# simple_conv = Conversation( -# system="You are GOT, a large language model trained by Foundation Model Group, Megvii Technology, based on LLaMA architecture." -# "You are designed to assist human with a variety of tasks using natural language." -# "Follow the instructions carefully.", -# roles=("Human", "Assistant"), -# messages=( -# ("Human", "Hi!"), -# ("Assistant", "Hi there! How can I help you today?\n") -# ), -# offset=2, -# sep_style=SeparatorStyle.SINGLE, -# sep="###", -# ) - simple_conv = Conversation( system="", @@ -403,6 +348,7 @@ def dict(self): sep="###", ) + simple_conv_multimodal = Conversation( system="You are GOT, a large language and vision assistant trained by Foundation Model Group, Megvii Technology." "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." @@ -415,6 +361,7 @@ def dict(self): sep="###", ) + simple_conv_mpt_multimodal = Conversation( system="""<|im_start|>system - You are GOT, a large language and vision assistant trained by Foundation Model Group, Megvii Technology. @@ -428,6 +375,7 @@ def dict(self): sep="<|im_end|>", ) + simple_conv_legacy = Conversation( system="You are GOT, a large language model trained by Foundation Model Group, Megvii Technology." "You are designed to assist human with a variety of tasks using natural language." @@ -439,6 +387,7 @@ def dict(self): sep="###", ) + conv_llava_v1 = Conversation( system="You are GOT, a large language and vision assistant trained by Foundation Model Group, Megvii Technology." "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." @@ -468,7 +417,3 @@ def dict(self): "mpt": conv_mpt, "mpt_text": conv_mpt_text, } - - -if __name__ == "__main__": - print(default_conversation.get_prompt()) diff --git a/paddlemix/models/GOT/utils/utils.py b/paddlemix/models/GOT/utils/utils.py index 8674e735c..c16e2e2e0 100644 --- a/paddlemix/models/GOT/utils/utils.py +++ b/paddlemix/models/GOT/utils/utils.py @@ -12,139 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -# import datetime -# import logging - import paddle - -# import requests -from paddlenlp.generation.stopping_criteria import ( # StoppingCriteriaList, - StoppingCriteria, -) - -# import logging.handlers -# import os -# import sys +from paddlenlp.generation.stopping_criteria import StoppingCriteria server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 
- handler = None -# def build_logger(logger_name, logger_filename): -# global handler - -# formatter = logging.Formatter( -# fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", -# datefmt="%Y-%m-%d %H:%M:%S", -# ) - -# # Set the format of root handlers -# if not logging.getLogger().handlers: -# logging.basicConfig(level=logging.INFO) -# logging.getLogger().handlers[0].setFormatter(formatter) - -# # Redirect stdout and stderr to loggers -# stdout_logger = logging.getLogger("stdout") -# stdout_logger.setLevel(logging.INFO) -# sl = StreamToLogger(stdout_logger, logging.INFO) -# sys.stdout = sl - -# stderr_logger = logging.getLogger("stderr") -# stderr_logger.setLevel(logging.ERROR) -# sl = StreamToLogger(stderr_logger, logging.ERROR) -# sys.stderr = sl - -# # Get logger -# logger = logging.getLogger(logger_name) -# logger.setLevel(logging.INFO) - -# # Add a file handler for all loggers -# if handler is None: -# os.makedirs(LOGDIR, exist_ok=True) -# filename = os.path.join(LOGDIR, logger_filename) -# handler = logging.handlers.TimedRotatingFileHandler( -# filename, when='D', utc=True) -# handler.setFormatter(formatter) - -# for name, item in logging.root.manager.loggerDict.items(): -# if isinstance(item, logging.Logger): -# item.addHandler(handler) - -# return logger - - -# class StreamToLogger(object): -# """ -# Fake file-like stream object that redirects writes to a logger instance. -# """ -# def __init__(self, logger, log_level=logging.INFO): -# self.terminal = sys.stdout -# self.logger = logger -# self.log_level = log_level -# self.linebuf = '' - -# def __getattr__(self, attr): -# return getattr(self.terminal, attr) - -# def write(self, buf): -# temp_linebuf = self.linebuf + buf -# self.linebuf = '' -# for line in temp_linebuf.splitlines(True): -# # From the io.TextIOWrapper docs: -# # On output, if newline is None, any '\n' characters written -# # are translated to the system default line separator. -# # By default sys.stdout.write() expects '\n' newlines and then -# # translates them so this is still cross platform. -# if line[-1] == '\n': -# self.logger.log(self.log_level, line.rstrip()) -# else: -# self.linebuf += line - -# def flush(self): -# if self.linebuf != '': -# self.logger.log(self.log_level, self.linebuf.rstrip()) -# self.linebuf = '' - - -# def disable_torch_init(): -# """ -# Disable the redundant torch default initialization to accelerate model creation. -# """ -# import torch -# setattr(torch.nn.Linear, "reset_parameters", lambda self: None) -# setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) - - -# def violates_moderation(text): -# """ -# Check whether the text violates OpenAI moderation API. 
-# """ -# url = "https://api.openai.com/v1/moderations" -# headers = {"Content-Type": "application/json", -# "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} -# text = text.replace("\n", "") -# data = "{" + '"input": ' + f'"{text}"' + "}" -# data = data.encode("utf-8") -# try: -# ret = requests.post(url, headers=headers, data=data, timeout=5) -# flagged = ret.json()["results"][0]["flagged"] -# except requests.exceptions.RequestException as e: -# flagged = False -# except KeyError as e: -# flagged = False - -# return flagged - - -# def pretty_print_semaphore(semaphore): -# if semaphore is None: -# return "None" -# return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" - - class KeywordsStoppingCriteria(StoppingCriteria): def __init__(self, keywords, tokenizer, input_ids): self.keywords = keywords diff --git a/paddlemix/processors/__init__.py b/paddlemix/processors/__init__.py index 4bb4ac5b3..7a05f5974 100644 --- a/paddlemix/processors/__init__.py +++ b/paddlemix/processors/__init__.py @@ -32,4 +32,5 @@ from .visualglm_processing import * from .image_processing_minicpmv import * from .processing_minicpmv import * -from .janus_processing import * \ No newline at end of file +from .janus_processing import * +from .got_process import * diff --git a/paddlemix/models/GOT/model/plug/blip_process.py b/paddlemix/processors/got_process.py similarity index 98% rename from paddlemix/models/GOT/model/plug/blip_process.py rename to paddlemix/processors/got_process.py index 6ba5fc558..2b47b4b99 100644 --- a/paddlemix/models/GOT/model/plug/blip_process.py +++ b/paddlemix/processors/got_process.py @@ -1,16 +1,7 @@ import paddle - -""" - Copyright (c) 2022, salesforce.com, inc. - All rights reserved. - SPDX-License-Identifier: BSD-3-Clause - For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause -""" import cv2 import numpy as np -# from PIL import Image - class BaseProcessor: def __init__(self): From e403d424a4e4614ec1511c2a0cc3a6c2eb150bb0 Mon Sep 17 00:00:00 2001 From: MqLeet Date: Mon, 16 Dec 2024 22:41:39 +0800 Subject: [PATCH 3/3] update README in GOT-OCR_2.0 --- paddlemix/examples/GOT_OCR_2_0/README.md | 35 ++---------------------- 1 file changed, 2 insertions(+), 33 deletions(-) diff --git a/paddlemix/examples/GOT_OCR_2_0/README.md b/paddlemix/examples/GOT_OCR_2_0/README.md index ede4ba658..c74efd13b 100644 --- a/paddlemix/examples/GOT_OCR_2_0/README.md +++ b/paddlemix/examples/GOT_OCR_2_0/README.md @@ -36,40 +36,9 @@ python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py \ --ocr_type format \ ``` -### 3.3. fine-grained OCR: +## 4 训练 ```bash -python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py \ - --model_name_or_path stepfun-ai/GOT-OCR2_0 \ - --image_file paddlemix/demo_images/hospital.jpeg \ - --ocr_type ocr \ - --box [x1,y1,x2,y2] \ -``` - -```bash -python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py \ - --model_name_or_path stepfun-ai/GOT-OCR2_0 \ - --image_file paddlemix/demo_images/hospital.jpeg \ - --ocr_type ocr \ - --color red \ -``` - -### 3.4. 
multi-crop OCR: -```bash -python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py \ - --model_name_or_path stepfun-ai/GOT-OCR2_0 \ - --image_file paddlemix/demo_images/hospital.jpeg \ - --multi_crop \ - --ocr_type ocr \ -``` - -```bash -# render the formatted OCR results: -python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py \ - --model_name_or_path stepfun-ai/GOT-OCR2_0 \ - --image_file paddlemix/demo_images/hospital.jpeg \ - --multi_crop \ - --ocr_type ocr \ - --render \ +sh paddlemix/examples/GOT_OCR_2_0/run_train.sh ```
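Note that `run_train.sh` reads its launch configuration from environment variables (`GPUS`, `BATCH_SIZE`, `PER_DEVICE_BATCH_SIZE`, `tensor_parallel_degree`) and derives the gradient-accumulation steps and the sharding degree from them, so a run can be resized without editing the script. A minimal sketch with illustrative values for a single 8-GPU node:

```bash
# Single-node example: 8 GPUs, global batch size 16, 1 sample per device
# -> GRADIENT_ACC = 16 / 1 / 8 = 2, sharding_parallel_degree = 8 / 1 = 8
GPUS=8 BATCH_SIZE=16 PER_DEVICE_BATCH_SIZE=1 tensor_parallel_degree=1 \
    sh paddlemix/examples/GOT_OCR_2_0/run_train.sh
```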