diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 708db5c22..25e518f2a 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -120,6 +120,10 @@ process: min_ratio: 0.333 # the min aspect ratio of filter range max_ratio: 3.0 # the max aspect ratio of filter range any_or_all: any # keep this sample when any/all images meet the filter condition + - image_size_filter: # filter samples according to the size of images (in bytes) within them + min_size: "0" # the min size of filter range + max_ratio: "1TB" # the max size of filter range + any_or_all: any # keep this sample when any/all images meet the filter condition - language_id_score_filter: # filter text in specific language with language scores larger than a specific max value lang: en # keep text in what language min_score: 0.8 # the min language scores to filter text diff --git a/data_juicer/ops/filter/image_size_filter.py b/data_juicer/ops/filter/image_size_filter.py index e5fd9455e..79e513ee3 100644 --- a/data_juicer/ops/filter/image_size_filter.py +++ b/data_juicer/ops/filter/image_size_filter.py @@ -16,7 +16,7 @@ class ImageSizeFilter(Filter): def __init__(self, min_size: str = '0', - max_size: str = '1Tb', + max_size: str = '1TB', any_or_all: str = 'any', *args, **kwargs): @@ -35,8 +35,8 @@ def __init__(self, :param kwargs: extra args """ super().__init__(*args, **kwargs) - self.min_size = min_size - self.max_size = max_size + self.min_size = size_to_bytes(min_size) + self.max_size = size_to_bytes(max_size) if any_or_all not in ['any', 'all']: raise ValueError(f'Keep strategy [{any_or_all}] is not supported. ' f'Can only be one of ["any", "all"].') @@ -63,8 +63,8 @@ def compute_stats(self, sample, context=False): def process(self, sample): image_sizes = sample[Fields.stats][StatsKeys.image_sizes] keep_bools = np.array([ - size_to_bytes(self.min_size) <= image_size <= size_to_bytes( - self.max_size) for image_size in image_sizes + self.min_size <= image_size <= self.max_size + for image_size in image_sizes ]) if len(keep_bools) <= 0: return True diff --git a/data_juicer/utils/mm_utils.py b/data_juicer/utils/mm_utils.py index 0bcf8590e..ea6b2063f 100644 --- a/data_juicer/utils/mm_utils.py +++ b/data_juicer/utils/mm_utils.py @@ -63,4 +63,5 @@ def size_to_bytes(size): else: raise ValueError(f'You specified unidentifiable unit: {suffix}, ' f'expected in [KB, MB, GB, TB, PB, EB, ZB, YB, ' - f'KiB, MiB, GiB, TiB, PiB, EiB, ZiB, YiB]') + f'KiB, MiB, GiB, TiB, PiB, EiB, ZiB, YiB], ' + f'(case insensitive, counted by *Bytes*).')