added the generated_caption_mapper #131

Merged: 12 commits, Jan 11, 2024
4 changes: 4 additions & 0 deletions configs/config_all.yaml
@@ -50,6 +50,10 @@ process:
- clean_copyright_mapper: # remove copyright comments.
- expand_macro_mapper: # expand macro definitions in Latex text.
- fix_unicode_mapper: # fix unicode errors in text.
- generate_caption_mapper:
hf_blip2: 'Salesforce/blip2-opt-2.7b' # blip2 model name on huggingface to generate caption
caption_num: 1 # how many candidate captions to generate for each image
keep_candidate_mode: 'random_any' # retain strategy for the generated $caption_num$ candidates. should be in ["random_any", "similar_one_simhash", "all"].
- nlpaug_en_mapper: # simply augment texts in English based on the nlpaug library
sequential: false # whether combine all augmentation methods to a sequence. If it's True, a sample will be augmented by all opened augmentation methods sequentially. If it's False, each opened augmentation method would generate its augmented samples independently.
aug_num: 1 # number of augmented samples to be generated. If `sequential` is True, there will be total aug_num augmented samples generated. If it's False, there will be (aug_num * #opened_aug_method) augmented samples generated.
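For reference, a minimal sketch of constructing the new op directly with the parameters exposed above (values other than the model name are varied here purely for illustration; instantiating the op loads the BLIP2 weights from huggingface):

from data_juicer.ops.mapper.generate_caption_mapper import GenerateCaptionMapper

op = GenerateCaptionMapper(
    hf_blip2='Salesforce/blip2-opt-2.7b',       # BLIP2 model used to caption each image
    caption_num=4,                              # generate 4 candidate captions per image
    keep_candidate_mode='similar_one_simhash',  # keep the candidate closest to the original text
)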
2 changes: 1 addition & 1 deletion data_juicer/ops/mapper/__init__.py
@@ -2,7 +2,7 @@
from . import (chinese_convert_mapper, clean_copyright_mapper,
clean_email_mapper, clean_html_mapper, clean_ip_mapper,
clean_links_mapper, expand_macro_mapper, fix_unicode_mapper,
nlpaug_en_mapper, nlpcda_zh_mapper,
generate_caption_mapper, nlpaug_en_mapper, nlpcda_zh_mapper,
punctuation_normalization_mapper, remove_bibliography_mapper,
remove_comments_mapper, remove_header_mapper,
remove_long_words_mapper, remove_non_chinese_character_mapper,
258 changes: 258 additions & 0 deletions data_juicer/ops/mapper/generate_caption_mapper.py
@@ -0,0 +1,258 @@
import copy
import random

import numpy as np
from jsonargparse.typing import PositiveInt

from data_juicer.utils.availability_utils import AvailabilityChecking
from data_juicer.utils.constant import HashKeys
from data_juicer.utils.mm_utils import (SpecialTokens,
insert_texts_after_placeholders,
load_image, remove_non_special_tokens,
remove_special_tokens)
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, Mapper
from ..deduplicator.document_simhash_deduplicator import (
DocumentSimhashDeduplicator, num_differing_bits)
from ..op_fusion import LOADED_IMAGES

OP_NAME = 'generate_caption_mapper'

with AvailabilityChecking(['torch', 'transformers'], OP_NAME):
import torch
import transformers # noqa: F401

# avoid hanging when calling blip2 in multiprocessing
torch.set_num_threads(1)


@OPERATORS.register_module(OP_NAME)
@LOADED_IMAGES.register_module(OP_NAME)
class GenerateCaptionMapper(Mapper):
"""Mapper to generate samples whose captions are generated based on
another model and the figure."""

def __init__(self,
hf_blip2='Salesforce/blip2-opt-2.7b',
caption_num: PositiveInt = 1,
keep_candidate_mode: str = 'random_any',
*args,
**kwargs):
"""
Initialization method.

:param hf_blip2: blip2 model name on huggingface to generate caption
:param caption_num: how many candidate captions to generate
for each image
:param keep_candidate_mode: retain strategy for the generated
$caption_num$ candidates.
'random_any': Retain the random one from generated captions
'similar_one_simhash': Retain the generated one that is most
similar to the original caption
'all': Retain all generated captions by concatenation
            Note: This is a batched_OP, whose input and output types are
            both lists of samples. Suppose there are $N$ input sample lists,
            each with batch size $b$, and denote caption_num as $M$.
            The total number of samples after generation is $2Nb$
            for 'random_any' and 'similar_one_simhash' modes,
            and $(1+M)Nb$ for 'all' mode.
:param args: extra args
:param kwargs: extra args
"""
super().__init__(*args, **kwargs)
self._batched_op = True
if keep_candidate_mode not in [
'random_any', 'similar_one_simhash', 'all'
]:
raise ValueError(
f'Keep strategy [{keep_candidate_mode}] is not supported. '
f'Can only be one of '
f'["random_any", "similar_one_simhash", "all"].')

self.model_key = prepare_model(model_type='hf_blip',
model_key=hf_blip2,
usage='conditional_generation')
model, img_processor = get_model(model_key=self.model_key,
usage='conditional_generation')
self.model_in_ctx = model
self.img_processor_in_ctx = img_processor
self.caption_num = caption_num
self.keep_candidate_mode = keep_candidate_mode
self.extra_args = kwargs

if keep_candidate_mode in ['random_any', 'similar_one_simhash']:
self.num_newly_generated_samples = 1
elif keep_candidate_mode in ['all']:
self.num_newly_generated_samples = self.caption_num
else:
self.num_newly_generated_samples = 0
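        # worked example: with caption_num=3, 'random_any' and
        # 'similar_one_simhash' add 1 new sample per original sample,
        # while 'all' adds 3; the original sample itself is always kept
        # by process()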

def _process_single_sample(self, ori_sample):
"""

:param ori_sample: a single data sample before applying generation
:return: batched results after generation
"""
# there is no image in this sample
if self.image_key not in ori_sample or \
not ori_sample[self.image_key]:
return []

# the generated results
generated_samples = [
copy.deepcopy(ori_sample)
for _ in range(self.num_newly_generated_samples)
]
for generated_sample in generated_samples:
generated_sample[self.text_key] = ''

# 1. load all image(s)
loaded_image_keys = ori_sample[self.image_key]
images = {}
for loaded_image_key in loaded_image_keys:
if loaded_image_key not in images:
# avoid loading the same images
image = load_image(loaded_image_key)
images[loaded_image_key] = image

offset = 0

        # we follow this assumption:
        # all text/img/video/audio data within a chunk are correlated.
        # As a result,
        # the original text will be removed,
        # the generated text will be placed after each SpecialTokens.img,
        # and the original special tokens are kept in an order-preserving way.
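        # e.g. (hypothetical sample with the default special tokens):
        #   text   = f'{SpecialTokens.image} an old photo {SpecialTokens.eoc}'
        #   images = ['photo.jpg']
        # each generated sample's text then becomes
        #   f'{SpecialTokens.image} <generated caption> {SpecialTokens.eoc}'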

# do generation for each image chunk by chunk
for chunk in ori_sample[self.text_key].split(SpecialTokens.eoc):
# skip empty chunks or contents after the last eoc token
if not chunk.strip():
continue

img_count = chunk.count(SpecialTokens.image)
text_with_only_special_tokens = remove_non_special_tokens(chunk)
image_chunk = []
for image_key in loaded_image_keys[offset:offset + img_count]:
image = images[image_key]
image_chunk.append(image)

# 2. generate candidate caption(s) in batch manner
generated_text_candidates_single_chunk = \
[[] for _ in range(self.caption_num)]
            # an auxiliary 2-D array:
            # generated_text_candidates_single_chunk[i][j] is
            # the $i$-th generated candidate for the $j$-th image

inputs = self.img_processor_in_ctx(images=image_chunk,
return_tensors='pt')
for i in range(self.caption_num):
generated_ids = self.model_in_ctx.generate(**inputs,
do_sample=True)
generated_text = self.img_processor_in_ctx.batch_decode(
generated_ids, skip_special_tokens=True)
generated_text_candidates_single_chunk[i] = generated_text

# 3. insert a list of generated captions into the positions of
# subsequent placeholders in the original string
new_generated_text_all_images = \
[[] for _ in range(self.num_newly_generated_samples)]
# new_generated_text_all_images is a helper array, element [i][j]
# denotes the reduced $i$-th result for the $j$-th image

# reduce the captions according to given mode image by image
for j in range(img_count):
new_generated_text_per_image = self._reduce_captions_per_image(
chunk, [
captions[j]
for captions in generated_text_candidates_single_chunk
])
assert self.num_newly_generated_samples == \
len(new_generated_text_per_image)
for i in range(len(new_generated_text_per_image)):
new_generated_text_all_images[i].append(
new_generated_text_per_image[i])

# insert the captions according to given mode
place_holders = [SpecialTokens.image] * img_count
for i in range(self.num_newly_generated_samples):
new_generated_text_per_chunk = insert_texts_after_placeholders(
original_string=text_with_only_special_tokens,
placeholders=place_holders,
new_texts=new_generated_text_all_images[i])
generated_samples[i][self.text_key] += \
f'{new_generated_text_per_chunk}{SpecialTokens.eoc}'

offset += img_count

return generated_samples

def _reduce_captions_per_image(self, chunk,
generated_text_candidates_single_chunk):
new_generated_text_per_chunk = []
if self.keep_candidate_mode == 'random_any':
new_generated_text_per_chunk.append(
random.choice(generated_text_candidates_single_chunk))
elif self.keep_candidate_mode == 'all':
new_generated_text_per_chunk.extend(
generated_text_candidates_single_chunk)
elif self.keep_candidate_mode == 'similar_one_simhash':
ori_normal_text = remove_special_tokens(chunk)
# using a simhash OP to calculate their similarity
# NOTE: simhash is just one method to calculate the similarities
# between texts, but not the most accurate one. More methods (e.g.
# embedding-based, ...) will be added.
op_simhash = DocumentSimhashDeduplicator(window_size=2,
**self.extra_args)
ori_text_hash = np.uint64(
op_simhash.compute_hash({op_simhash.text_key:
ori_normal_text})[HashKeys.simhash])
generated_text_hashes = [
np.uint64(
op_simhash.compute_hash(
{op_simhash.text_key:
candidate_text})[HashKeys.simhash])
for candidate_text in generated_text_candidates_single_chunk
]
hamming_distances = [
num_differing_bits(ori_text_hash, generated_text_hash)
for generated_text_hash in generated_text_hashes
]
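            # the candidate whose simhash has the smallest hamming distance
            # to the original text's simhash is considered the most similar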
max_index = min(range(len(hamming_distances)),
key=hamming_distances.__getitem__)
new_generated_text_per_chunk.append(
generated_text_candidates_single_chunk[max_index])
return new_generated_text_per_chunk

def process(self, samples):
"""
Note: This is a batched_OP, whose the input and output type are
both list. Suppose there are $N$ input sample list with batch
size as $b$, and denote caption_num as $M$.
the number of total samples after generation is $2Nb$
for 'random_any' and 'similar_one' mode,
and $(1+M)Nb$ for 'all' mode.
:param samples:
:return:
"""
# reconstruct samples from "dict of lists" to "list of dicts"
reconstructed_samples = []
for i in range(len(samples[self.text_key])):
reconstructed_samples.append(
{key: samples[key][i]
for key in samples})
samples_after_generation = []
# do generation for each sample within the batch
for ori_sample in reconstructed_samples:
samples_after_generation.append(ori_sample)
generated_samples = self._process_single_sample(ori_sample)
if len(generated_samples) != 0:
samples_after_generation.extend(generated_samples)
# reconstruct samples from "list of dicts" to "dict of lists"
keys = samples_after_generation[0].keys()
res_samples = {}
for key in keys:
res_samples[key] = [s[key] for s in samples_after_generation]

return res_samples
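A rough sketch of the expected batched behaviour, reusing the op instance from the sketch above (the 'text'/'images' keys are the assumed data-juicer defaults and the image path is hypothetical; running this for real requires the BLIP2 model):

from data_juicer.utils.mm_utils import SpecialTokens

samples = {
    'text': [f'{SpecialTokens.image} original caption {SpecialTokens.eoc}'],
    'images': [['path/to/image.jpg']],
}
res = op.process(samples)
# every original sample is kept and followed by its generated sample(s):
# a batch of b samples grows to 2b for 'random_any'/'similar_one_simhash',
# and to (1 + caption_num) * b for 'all'
len(res['text'])  # -> 2 for this single-sample batch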
37 changes: 37 additions & 0 deletions data_juicer/utils/mm_utils.py
@@ -1,3 +1,5 @@
import re

import numpy as np
from datasets import Audio, Image

@@ -30,6 +32,15 @@ def remove_special_tokens(text):
return text


def remove_non_special_tokens(text):
special_tokens = get_special_tokens().values()
patterns = '|'.join(re.escape(token) for token in special_tokens)
special_tokens_found = re.findall(patterns, text)
text_with_only_special_tokens = ''.join(special_tokens_found)

return text_with_only_special_tokens
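
# for instance (hypothetical text with the default special tokens), only the
# placeholders survive, in their original order:
#   remove_non_special_tokens(f'{SpecialTokens.image} a cat {SpecialTokens.eoc}')
#   -> f'{SpecialTokens.image}{SpecialTokens.eoc}'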


# Image
def load_images(paths):
return [load_image(path) for path in paths]
@@ -118,3 +129,29 @@ def size_to_bytes(size):
f'expected in [KB, MB, GB, TB, PB, EB, ZB, YB, '
f'KiB, MiB, GiB, TiB, PiB, EiB, ZiB, YiB], '
f'(case insensitive, counted by *Bytes*).')


def insert_texts_after_placeholders(original_string,
placeholders,
new_texts,
delimiter_in_insert_pos=' '):
if len(placeholders) != len(new_texts):
raise ValueError(
'The number of placeholders and new_texts must be equal')

modified_string = original_string
for placeholder, new_text in zip(placeholders, new_texts):
# Find the index of the next occurrence of the placeholder
index = modified_string.find(placeholder)
if index == -1:
raise ValueError(
f"Placeholder '{placeholder}' not found in the string")
# Insert new_text at the found index position
modified_string = \
modified_string[:index + len(placeholder)] + \
delimiter_in_insert_pos + \
new_text + \
delimiter_in_insert_pos + \
modified_string[index + len(placeholder):]

return modified_string
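
# usage sketch (hypothetical placeholder and caption): the new text is wrapped
# with the delimiter and placed right after the matched placeholder:
#   insert_texts_after_placeholders('<img>', ['<img>'], ['a cat'])
#   -> '<img> a cat '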