Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#100 from LokeZhou/minigpt4
Browse files Browse the repository at this point in the history
[Appflow] add Minigpt4
  • Loading branch information
LokeZhou authored Aug 24, 2023
2 parents 7f67db4 + d504124 commit 700d04a
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 8 deletions.
3 changes: 2 additions & 1 deletion applications/Automatic_label/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ task = Appflow(app="auto_label",
)
url = "https://paddlenlp.bj.bcebos.com/models/community/CompVis/stable-diffusion-v1-4/overture-creations.png"
image_pil = load_image(url)
result = task(image=image_pil)
blip2_prompt = 'describe the image'
result = task(image=image_pil,blip2_prompt = blip2_prompt)
```

效果展示
Expand Down
67 changes: 67 additions & 0 deletions applications/image2text/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@


### 图文生成(Image-to-Text Generation)

## miniGPT4
使用miniGPT4前,需要下载相应权重进行转换,具体可参考[miniGPT4](../../paddlemix/examples/minigpt4/README.md),在完成权重转换后,根据模型权重文件以及配置文件按下存放:
```bash
--PPMIX_HOME #默认路径 /root/.paddlemix 可通过export PPMIX_HOME 设置
--models
--miniGPT4
--MiniGPT4-7B
config.json
model_state.pdparams
special_tokens_map.json
image_preprocessor_config.json
preprocessor_config.json
tokenizer_config.json
model_config.json
sentencepiece.bpe.model
tokenizer.json
--MiniGPT4-13B
...
...
...

```
完成之后,可使用appflow 一键预测
```python
from paddlemix import Appflow
import requests

task = Appflow(app="image2text_generation",
models=["miniGPT4/MiniGPT4-7B"])
url = "https://paddlenlp.bj.bcebos.com/data/images/mugs.png"
image = Image.open(requests.get(url, stream=True).raw)
minigpt4_text = "describe the image"
result = task(image=image,minigpt4_text=minigpt4_text)
```

效果展示

<div align="center">

| Image | text | Generated text|
|:----:|:----:|:----:|
|![mugs](https://github.com/LokeZhou/PaddleMIX/assets/13300429/b5a95002-bb30-4683-8e62-ed21879f24e1) | describe the image|The image shows two mugs with cats on them, one is black and white and the other is blue and white. The mugs are sitting on a table with a book in the background. The mugs have a whimsical, cartoon-like appearance. The cats on the mugs are looking at each other with a playful expression. The overall style of the image is cute and fun.###|
</div>

## blip2

```python
from paddlemix import Appflow
from ppdiffusers.utils import load_image

task = Appflow(app="image2text_generation",
models=["paddlemix/blip2-caption-opt2.7b"])
url = "https://paddlenlp.bj.bcebos.com/data/images/mugs.png"
image_pil = load_image(url)
blip2_prompt = 'describe the image'
result = task(image=image_pil,blip2_prompt=blip2_prompt)
```

| Image | text | Generated text|
|:----:|:----:|:----:|
|![mugs](https://github.com/LokeZhou/PaddleMIX/assets/13300429/b5a95002-bb30-4683-8e62-ed21879f24e1) | describe the image|of the two coffee mugs with cats on them|
</div>

17 changes: 16 additions & 1 deletion paddlemix/appflow/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
StableDiffusionImg2ImgTask,
StableDiffusionUpscaleTask,
)
from .image2text_generation import Blip2CaptionTask
from .image2text_generation import Blip2CaptionTask, MiniGPT4Task
from .openset_det_sam import OpenSetDetTask, OpenSetSegTask
from .text2image_generation import StableDiffusionTask, VersatileDiffusionDualGuidedTask
from .text2image_inpaiting import StableDiffusionInpaintTask
Expand Down Expand Up @@ -137,4 +137,19 @@
"model": "damo-vilab/text-to-video-ms-1.7b",
},
},
"image2text_generation": {
"models": {
"paddlemix/blip2-caption-opt2.7b": {
"task_class": Blip2CaptionTask,
"task_flag": "autolabel_blip2-caption-opt2.7b",
},
"miniGPT4/MiniGPT4-7B": {
"task_class": MiniGPT4Task,
"task_flag": "image2text_generation-MiniGPT4-7B",
},
},
"default": {
"model": "paddlemix/blip2-caption-opt2.7b",
},
},
}
93 changes: 87 additions & 6 deletions paddlemix/appflow/image2text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
import nltk
from paddlenlp.transformers import AutoTokenizer

from paddlemix.models import MiniGPT4ForConditionalGeneration
from paddlemix.models.blip2.modeling import Blip2ForConditionalGeneration
from paddlemix.processors import MiniGPT4Processor
from paddlemix.processors.blip_processing import (
Blip2Processor,
BlipImageProcessor,
Expand Down Expand Up @@ -64,8 +66,8 @@ def _preprocess(self, inputs):
""" """
image = inputs.get("image", None)
assert image is not None, "The image is None"

prompt = "describe the image"
prompt = inputs.get("blip2_prompt", None)
assert image is not None, "The blip2_prompt is None"

blip2_input = self._processor(
images=image,
Expand Down Expand Up @@ -99,8 +101,7 @@ def _postprocess(self, inputs):
generated_text = self._processor.batch_decode(inputs["result"], skip_special_tokens=True)[0].strip()
logger.info("Generate text: {}".format(generated_text))

inputs.pop("result", None)

inputs["result"] = generated_text
inputs["prompt"] = self._generate_tags(generated_text)

return inputs
Expand All @@ -111,6 +112,86 @@ def _generate_tags(self, caption):
nltk.download(["punkt", "averaged_perceptron_tagger", "wordnet"])
tags_list = [word for (word, pos) in nltk.pos_tag(nltk.word_tokenize(caption)) if pos[0] == "N"]
tags_lemma = [lemma.lemmatize(w) for w in tags_list]
tags = ", ".join(map(str, tags_lemma))
tags = ",".join(map(str, tags_lemma))
tags = set(tags.split(","))
new_tags = ",".join(tags)
return new_tags


class MiniGPT4Task(AppTask):
def __init__(self, task, model, **kwargs):
super().__init__(task=task, model=model, **kwargs)

self._generate_kwargs = {
"max_length": 300,
"num_beams": 1,
"top_p": 1.0,
"top_k": 0,
"repetition_penalty": 1.0,
"length_penalty": 0.0,
"temperature": 1.0,
"decode_strategy": "greedy_search",
"eos_token_id": [[835], [2277, 29937]],
}
# Default to static mode
self._static_mode = False

self._construct_processor(model)
self._construct_model(model)

def _construct_processor(self, model):
"""
Construct the tokenizer for the predictor.
"""

self._processor = MiniGPT4Processor.from_pretrained(model)

def _construct_model(self, model):
"""
Construct the inference model for the predictor.
"""
# bulid model
model_instance = MiniGPT4ForConditionalGeneration.from_pretrained(self._task_path)

self._model = model_instance
self._model.eval()

return tags
def _preprocess(self, inputs):
""" """
image = inputs.get("image", None)
assert image is not None, "The image is None"
minigpt4_text = inputs.get("minigpt4_text", None)
assert minigpt4_text is not None, "The minigpt4_text is None"

prompt = "Give the following image: <Img>ImageContent</Img>. You will be able to see the image once I provide it to you. Please answer my questions.###Human: <Img><ImageHere></Img> <TextHere>###Assistant:"
minigpt4_input = self._processor([image], minigpt4_text, prompt)

inputs.pop("minigpt4_text", None)
inputs["minigpt4_input"] = minigpt4_input

return inputs

def _run_model(self, inputs):
"""
Run the task model from the outputs of the `_preprocess` function.
"""
generate_kwargs = inputs.get("generate_kwargs", None)
generate_kwargs = self._generate_kwargs if generate_kwargs is None else generate_kwargs
outputs = self._model.generate(**inputs["minigpt4_input"], **generate_kwargs)

inputs.pop("minigpt4_input", None)

inputs["result"] = outputs

return inputs

def _postprocess(self, inputs):
"""
The model output is tag ids, this function will convert the model output to raw text.
"""
generated_text = self._processor.batch_decode(inputs["result"][0])[0]
logger.info("Generate text: {}".format(generated_text))

inputs["result"] = generated_text

return inputs

0 comments on commit 700d04a

Please sign in to comment.