Skip to content

Commit

Permalink
style: change param names
Browse files Browse the repository at this point in the history
  • Loading branch information
zhijianma committed Nov 15, 2023
1 parent 74a094e commit c84d1fa
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 30 deletions.
4 changes: 2 additions & 2 deletions configs/config_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ process:
max_ratio: 0.5 # the max ratio of filter range
- clip_similarity_filter: # filter samples according to the similarity between text and images.
hf_clip: openai/clip-vit-base-patch32 # name of used Hugging Face clip
min_ratio: 0.24 # the min similarity of filter range
max_ratio: 1.0 # the max similarity of filter range
min_score: 0.1 # the min similarity of filter range
max_socre: 1.0 # the max similarity of filter range
reduce_mode: avg # reduce mode when one text corresponds to multiple images in a chunk, must be one of ['avg','max', 'min'].
any_or_all: any # keep this sample when any/all images meet the filter condition
- flagged_words_filter: # filter text with the flagged-word ratio larger than a specific max value
Expand Down
18 changes: 9 additions & 9 deletions data_juicer/ops/filter/clip_similarity_filter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
import torch
from jsonargparse.typing import PositiveFloat
from jsonargparse.typing import ClosedUnitInterval

from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.mm_utils import SpecialTokens, load_image
Expand All @@ -21,8 +21,8 @@ class ClipSimilarityFilter(Filter):

def __init__(self,
hf_clip='openai/clip-vit-base-patch32',
min_ratio: PositiveFloat = 0.1,
max_ratio: PositiveFloat = 1.0,
min_score: ClosedUnitInterval = 0.1,
max_score: ClosedUnitInterval = 1.0,
any_or_all: str = 'any',
reduce_mode: str = 'avg',
*args,
Expand All @@ -32,8 +32,8 @@ def __init__(self,
:param hf_clip: clip model name on huggingface to compute
the similarity between image and text.
:param min_ratio: The min similarity to keep samples.
:param max_ratio: The max similarity to keep samples.
:param min_score: The min similarity to keep samples.
:param max_score: The max similarity to keep samples.
:param any_or_all: keep this sample with 'any' or 'all' strategy of
all images. 'any': keep this sample if any images meet the
condition. 'all': keep this sample only if all images meet the
Expand All @@ -47,8 +47,8 @@ def __init__(self,
:param kwargs: extra args
"""
super().__init__(*args, **kwargs)
self.min_ratio = min_ratio
self.max_ratio = max_ratio
self.min_score = min_score
self.max_score = max_score
if reduce_mode not in ['avg', 'max', 'min']:
raise ValueError(f'Reduce mode [{reduce_mode}] is not supported. '
f'Can only be one of ["avg", "max", "min"].')
Expand Down Expand Up @@ -97,7 +97,7 @@ def compute_stats(self, sample, context=False):
offset = 0

def remove_special_token(text):
for key, value in special_token_dict.items():
for value in special_token_dict.values():
text = text.replace(value, '')
return text

Expand Down Expand Up @@ -147,7 +147,7 @@ def process(self, sample):
return True

keep_bools = np.array([
self.min_ratio <= sim_value <= self.max_ratio
self.min_score <= sim_value <= self.max_score
for sim_value in similarity
])

Expand Down
38 changes: 19 additions & 19 deletions tests/ops/filter/test_clip_similarity_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

from datasets import Dataset

from data_juicer.ops.filter.clip_similarity_filter import (
ClipSimilarityFilter, SpecialTokens)
from data_juicer.ops.filter.clip_similarity_filter import ClipSimilarityFilter
from data_juicer.utils.constant import Fields
from data_juicer.utils.mm_utils import SpecialTokens


class ClipSimilarityFilterTest(unittest.TestCase):
Expand Down Expand Up @@ -50,8 +50,8 @@ def test_no_eoc_special_token(self):
op = ClipSimilarityFilter(hf_clip=self.hf_clip,
reduce_mode='avg',
any_or_all='any',
min_ratio=0.2,
max_ratio=0.9)
min_score=0.2,
max_score=0.9)
self._run_filter(dataset, tgt_list, op)

def test_eoc_special_token(self):
Expand All @@ -74,8 +74,8 @@ def test_eoc_special_token(self):
op = ClipSimilarityFilter(hf_clip=self.hf_clip,
reduce_mode='avg',
any_or_all='any',
min_ratio=0.2,
max_ratio=0.9)
min_score=0.2,
max_score=0.9)
self._run_filter(dataset, tgt_list, op)

def test_keep_any(self):
Expand All @@ -96,8 +96,8 @@ def test_keep_any(self):
op = ClipSimilarityFilter(hf_clip=self.hf_clip,
reduce_mode='avg',
any_or_all='any',
min_ratio=0.2,
max_ratio=0.9)
min_score=0.2,
max_score=0.9)
self._run_filter(dataset, tgt_list, op)

def test_keep_all(self):
Expand All @@ -113,8 +113,8 @@ def test_keep_all(self):
op = ClipSimilarityFilter(hf_clip=self.hf_clip,
reduce_mode='avg',
any_or_all='all',
min_ratio=0.2,
max_ratio=0.9)
min_score=0.2,
max_score=0.9)
self._run_filter(dataset, tgt_list, op)

def test_reduce_avg(self):
Expand All @@ -133,8 +133,8 @@ def test_reduce_avg(self):
op = ClipSimilarityFilter(hf_clip=self.hf_clip,
reduce_mode='avg',
any_or_all='any',
min_ratio=0.2,
max_ratio=0.9)
min_score=0.2,
max_score=0.9)
self._run_filter(dataset, tgt_list, op)

def test_reduce_max(self):
Expand All @@ -153,8 +153,8 @@ def test_reduce_max(self):
op = ClipSimilarityFilter(hf_clip=self.hf_clip,
reduce_mode='max',
any_or_all='any',
min_ratio=0.2,
max_ratio=0.9)
min_score=0.2,
max_score=0.9)
self._run_filter(dataset, tgt_list, op)

def test_reduce_min(self):
Expand All @@ -174,12 +174,12 @@ def test_reduce_min(self):
op = ClipSimilarityFilter(hf_clip=self.hf_clip,
reduce_mode='min',
any_or_all='any',
min_ratio=0.1,
max_ratio=0.9)
min_score=0.1,
max_score=0.9)

self._run_filter(dataset, tgt_list, op)

op.min_ratio = 0.2
op.min_score = 0.2
self._run_filter(dataset, [], op)

def test_multi_process(self):
Expand All @@ -200,8 +200,8 @@ def test_multi_process(self):
op = ClipSimilarityFilter(hf_clip=self.hf_clip,
reduce_mode='avg',
any_or_all='any',
min_ratio=0.2,
max_ratio=0.9)
min_score=0.2,
max_score=0.9)
self._run_filter(dataset, tgt_list, op, num_proc=4)


Expand Down

0 comments on commit c84d1fa

Please sign in to comment.