From 0f67ba1d741d65b07d549daf4ee157609ce4f9c1 Mon Sep 17 00:00:00 2001 From: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Date: Tue, 25 Jun 2024 13:36:58 +0100 Subject: [PATCH] Add ViTImageProcessorFast to tests (#31424) * Add ViTImageProcessor to tests * Correct data format * Review comments --- src/transformers/image_processing_utils.py | 5 + .../image_processing_utils_fast.py | 5 + .../models/auto/image_processing_auto.py | 11 +- .../image_processing_mask2former.py | 2 +- .../maskformer/image_processing_maskformer.py | 2 +- .../oneformer/image_processing_oneformer.py | 2 +- .../models/vit/image_processing_vit_fast.py | 3 +- .../test_image_processing_bridgetower.py | 4 + .../test_image_processing_conditional_detr.py | 4 + .../test_image_processing_deformable_detr.py | 4 + .../models/detr/test_image_processing_detr.py | 4 + .../models/glpn/test_image_processing_glpn.py | 2 + .../test_image_processing_grounding_dino.py | 4 + .../idefics/test_image_processing_idefics.py | 4 + .../test_image_processing_idefics2.py | 158 +++++++++++------- .../test_image_processing_mask2former.py | 2 + .../test_image_processing_maskformer.py | 2 + .../test_image_processing_oneformer.py | 2 + .../oneformer/test_processor_oneformer.py | 2 + .../test_image_processing_pix2struct.py | 4 +- .../swin2sr/test_image_processing_swin2sr.py | 4 +- .../test_image_processing_video_llava.py | 4 +- .../models/vilt/test_image_processing_vilt.py | 4 + tests/models/vit/test_image_processing_vit.py | 6 +- .../yolos/test_image_processing_yolos.py | 3 + tests/test_image_processing_common.py | 82 +++++++-- 26 files changed, 236 insertions(+), 93 deletions(-) diff --git a/src/transformers/image_processing_utils.py b/src/transformers/image_processing_utils.py index 4b263446b54e2a..0279f26a963e35 100644 --- a/src/transformers/image_processing_utils.py +++ b/src/transformers/image_processing_utils.py @@ -151,6 +151,11 @@ def center_crop( **kwargs, ) + def to_dict(self): + encoder_dict = super().to_dict() + encoder_dict.pop("_valid_processor_keys", None) + return encoder_dict + VALID_SIZE_DICT_KEYS = ( {"height", "width"}, diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py index daeee3e1bd5bba..d1a08132d73d89 100644 --- a/src/transformers/image_processing_utils_fast.py +++ b/src/transformers/image_processing_utils_fast.py @@ -61,3 +61,8 @@ def _validate_params(self, **kwargs) -> None: def get_transforms(self, **kwargs) -> "Compose": self._validate_params(**kwargs) return self._build_transforms(**kwargs) + + def to_dict(self): + encoder_dict = super().to_dict() + encoder_dict.pop("_transform_params", None) + return encoder_dict diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index bd264b370bad6c..642c7bffcb666d 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -399,7 +399,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): kwargs["token"] = use_auth_token config = kwargs.pop("config", None) - use_fast = kwargs.pop("use_fast", False) + use_fast = kwargs.pop("use_fast", None) trust_remote_code = kwargs.pop("trust_remote_code", None) kwargs["_from_auto"] = True @@ -430,10 +430,11 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): if image_processor_class is not None: # Update class name to reflect the use_fast option. If class is not found, None is returned. 
- if use_fast and not image_processor_class.endswith("Fast"): - image_processor_class += "Fast" - elif not use_fast and image_processor_class.endswith("Fast"): - image_processor_class = image_processor_class[:-4] + if use_fast is not None: + if use_fast and not image_processor_class.endswith("Fast"): + image_processor_class += "Fast" + elif not use_fast and image_processor_class.endswith("Fast"): + image_processor_class = image_processor_class[:-4] image_processor_class = image_processor_class_from_name(image_processor_class) has_remote_code = image_processor_auto_map is not None diff --git a/src/transformers/models/mask2former/image_processing_mask2former.py b/src/transformers/models/mask2former/image_processing_mask2former.py index 6f35579978bdb8..695ae654ccba3d 100644 --- a/src/transformers/models/mask2former/image_processing_mask2former.py +++ b/src/transformers/models/mask2former/image_processing_mask2former.py @@ -772,7 +772,7 @@ def preprocess( ignore_index, do_reduce_labels, return_tensors, - input_data_format=input_data_format, + input_data_format=data_format, ) return encoded_inputs diff --git a/src/transformers/models/maskformer/image_processing_maskformer.py b/src/transformers/models/maskformer/image_processing_maskformer.py index e32722b074c4c1..73b428e0bab26d 100644 --- a/src/transformers/models/maskformer/image_processing_maskformer.py +++ b/src/transformers/models/maskformer/image_processing_maskformer.py @@ -772,7 +772,7 @@ def preprocess( ignore_index, do_reduce_labels, return_tensors, - input_data_format=input_data_format, + input_data_format=data_format, ) return encoded_inputs diff --git a/src/transformers/models/oneformer/image_processing_oneformer.py b/src/transformers/models/oneformer/image_processing_oneformer.py index 674a168e09491b..6936f088bfeeab 100644 --- a/src/transformers/models/oneformer/image_processing_oneformer.py +++ b/src/transformers/models/oneformer/image_processing_oneformer.py @@ -772,7 +772,7 @@ def preprocess( ignore_index, do_reduce_labels, return_tensors, - input_data_format=input_data_format, + input_data_format=data_format, ) return encoded_inputs diff --git a/src/transformers/models/vit/image_processing_vit_fast.py b/src/transformers/models/vit/image_processing_vit_fast.py index 09113761655b93..21f5a99a3e3d78 100644 --- a/src/transformers/models/vit/image_processing_vit_fast.py +++ b/src/transformers/models/vit/image_processing_vit_fast.py @@ -114,7 +114,6 @@ def __init__( self.rescale_factor = rescale_factor self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD - self._transform_settings = {} def _build_transforms( self, @@ -285,5 +284,5 @@ def preprocess( ) transformed_images = [transforms(image) for image in images] - data = {"pixel_values": torch.vstack(transformed_images)} + data = {"pixel_values": torch.stack(transformed_images, dim=0)} return BatchFeature(data, tensor_type=return_tensors) diff --git a/tests/models/bridgetower/test_image_processing_bridgetower.py b/tests/models/bridgetower/test_image_processing_bridgetower.py index 1dc5419b77c886..48268c8d3f5696 100644 --- a/tests/models/bridgetower/test_image_processing_bridgetower.py +++ b/tests/models/bridgetower/test_image_processing_bridgetower.py @@ -17,6 +17,8 @@ import unittest from typing import Dict, List, Optional, Union +import numpy as np + from transformers.testing_utils import require_torch, require_vision from transformers.utils import is_vision_available @@ -84,6 
+86,8 @@ def get_expected_values(self, image_inputs, batched=False): image = image_inputs[0] if isinstance(image, Image.Image): w, h = image.size + elif isinstance(image, np.ndarray): + h, w = image.shape[0], image.shape[1] else: h, w = image.shape[1], image.shape[2] scale = size / min(w, h) diff --git a/tests/models/conditional_detr/test_image_processing_conditional_detr.py b/tests/models/conditional_detr/test_image_processing_conditional_detr.py index 171ec2d44f499a..99a06613e141bb 100644 --- a/tests/models/conditional_detr/test_image_processing_conditional_detr.py +++ b/tests/models/conditional_detr/test_image_processing_conditional_detr.py @@ -18,6 +18,8 @@ import pathlib import unittest +import numpy as np + from transformers.testing_utils import require_torch, require_vision, slow from transformers.utils import is_torch_available, is_vision_available @@ -87,6 +89,8 @@ def get_expected_values(self, image_inputs, batched=False): image = image_inputs[0] if isinstance(image, Image.Image): w, h = image.size + elif isinstance(image, np.ndarray): + h, w = image.shape[0], image.shape[1] else: h, w = image.shape[1], image.shape[2] if w < h: diff --git a/tests/models/deformable_detr/test_image_processing_deformable_detr.py b/tests/models/deformable_detr/test_image_processing_deformable_detr.py index 51fbfc33f8c195..41e5a81e2f93c0 100644 --- a/tests/models/deformable_detr/test_image_processing_deformable_detr.py +++ b/tests/models/deformable_detr/test_image_processing_deformable_detr.py @@ -18,6 +18,8 @@ import pathlib import unittest +import numpy as np + from transformers.testing_utils import require_torch, require_vision, slow from transformers.utils import is_torch_available, is_vision_available @@ -87,6 +89,8 @@ def get_expected_values(self, image_inputs, batched=False): image = image_inputs[0] if isinstance(image, Image.Image): w, h = image.size + elif isinstance(image, np.ndarray): + h, w = image.shape[0], image.shape[1] else: h, w = image.shape[1], image.shape[2] if w < h: diff --git a/tests/models/detr/test_image_processing_detr.py b/tests/models/detr/test_image_processing_detr.py index fc6d5651272459..4174df0f8cc792 100644 --- a/tests/models/detr/test_image_processing_detr.py +++ b/tests/models/detr/test_image_processing_detr.py @@ -17,6 +17,8 @@ import pathlib import unittest +import numpy as np + from transformers.testing_utils import require_torch, require_vision, slow from transformers.utils import is_torch_available, is_vision_available @@ -86,6 +88,8 @@ def get_expected_values(self, image_inputs, batched=False): image = image_inputs[0] if isinstance(image, Image.Image): w, h = image.size + elif isinstance(image, np.ndarray): + h, w = image.shape[0], image.shape[1] else: h, w = image.shape[1], image.shape[2] if w < h: diff --git a/tests/models/glpn/test_image_processing_glpn.py b/tests/models/glpn/test_image_processing_glpn.py index abffb31a66936c..d4aa78656af537 100644 --- a/tests/models/glpn/test_image_processing_glpn.py +++ b/tests/models/glpn/test_image_processing_glpn.py @@ -66,6 +66,8 @@ def prepare_image_processor_dict(self): def expected_output_image_shape(self, images): if isinstance(images[0], Image.Image): width, height = images[0].size + elif isinstance(images[0], np.ndarray): + height, width = images[0].shape[0], images[0].shape[1] else: height, width = images[0].shape[1], images[0].shape[2] diff --git a/tests/models/grounding_dino/test_image_processing_grounding_dino.py b/tests/models/grounding_dino/test_image_processing_grounding_dino.py index 
68618fb256aa7a..5a28397847079f 100644 --- a/tests/models/grounding_dino/test_image_processing_grounding_dino.py +++ b/tests/models/grounding_dino/test_image_processing_grounding_dino.py @@ -18,6 +18,8 @@ import pathlib import unittest +import numpy as np + from transformers.testing_utils import require_torch, require_vision, slow from transformers.utils import is_torch_available, is_vision_available @@ -93,6 +95,8 @@ def get_expected_values(self, image_inputs, batched=False): image = image_inputs[0] if isinstance(image, Image.Image): w, h = image.size + elif isinstance(image, np.ndarray): + h, w = image.shape[0], image.shape[1] else: h, w = image.shape[1], image.shape[2] if w < h: diff --git a/tests/models/idefics/test_image_processing_idefics.py b/tests/models/idefics/test_image_processing_idefics.py index 0273480333f1be..cb2b294bf54844 100644 --- a/tests/models/idefics/test_image_processing_idefics.py +++ b/tests/models/idefics/test_image_processing_idefics.py @@ -16,6 +16,8 @@ import unittest +import numpy as np + from transformers.testing_utils import require_torch, require_torchvision, require_vision from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available @@ -75,6 +77,8 @@ def get_expected_values(self, image_inputs, batched=False): image = image_inputs[0] if isinstance(image, Image.Image): w, h = image.size + elif isinstance(image, np.ndarray): + h, w = image.shape[0], image.shape[1] else: h, w = image.shape[1], image.shape[2] scale = size / min(w, h) diff --git a/tests/models/idefics2/test_image_processing_idefics2.py b/tests/models/idefics2/test_image_processing_idefics2.py index 2e0d36e75c8a08..624fdd6c98b3e5 100644 --- a/tests/models/idefics2/test_image_processing_idefics2.py +++ b/tests/models/idefics2/test_image_processing_idefics2.py @@ -99,6 +99,8 @@ def get_expected_values(self, image_inputs, batched=False): image = image_inputs[0] if isinstance(image, Image.Image): w, h = image.size + elif isinstance(image, np.ndarray): + h, w = image.shape[0], image.shape[1] else: h, w = image.shape[1], image.shape[2] @@ -176,6 +178,10 @@ def prepare_image_inputs( if torchify: images_list = [[torch.from_numpy(image) for image in images] for images in images_list] + if numpify: + # Numpy images are typically in channels last format + images_list = [[image.transpose(1, 2, 0) for image in images] for images in images_list] + return images_list @@ -206,66 +212,100 @@ def test_image_processor_properties(self): self.assertTrue(hasattr(image_processing, "do_image_splitting")) def test_call_numpy(self): - # Initialize image_processing - image_processing = self.image_processing_class(**self.image_processor_dict) - # create random numpy tensors - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True) - for sample_images in image_inputs: - for image in sample_images: - self.assertIsInstance(image, np.ndarray) - - # Test not batched input - encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values - expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]]) - self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) - - # Test batched - encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values - expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs) - self.assertEqual( - tuple(encoded_images.shape), (self.image_processor_tester.batch_size, 
*expected_output_image_shape)
-        )
+        for image_processing_class in self.image_processor_list:
+            # Initialize image_processing
+            image_processing = image_processing_class(**self.image_processor_dict)
+            # create random numpy tensors
+            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
+            for sample_images in image_inputs:
+                for image in sample_images:
+                    self.assertIsInstance(image, np.ndarray)
+
+            # Test not batched input
+            encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
+            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
+            self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
+
+            # Test batched
+            encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
+            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
+            self.assertEqual(
+                tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
+            )
+
+    def test_call_numpy_4_channels(self):
+        for image_processing_class in self.image_processor_list:
+            # Initialize image_processing
+            image_processor_dict = self.image_processor_dict
+            image_processor_dict["image_mean"] = [0.5, 0.5, 0.5, 0.5]
+            image_processor_dict["image_std"] = [0.5, 0.5, 0.5, 0.5]
+            image_processing = image_processing_class(**image_processor_dict)
+            # create random numpy tensors
+            self.image_processor_tester.num_channels = 4
+            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
+
+            for sample_images in image_inputs:
+                for image in sample_images:
+                    self.assertIsInstance(image, np.ndarray)
+
+            # Test not batched input
+            encoded_images = image_processing(
+                image_inputs[0], input_data_format="channels_last", return_tensors="pt"
+            ).pixel_values
+            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
+            self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
+
+            # Test batched
+            encoded_images = image_processing(
+                image_inputs, input_data_format="channels_last", return_tensors="pt"
+            ).pixel_values
+            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
+            self.assertEqual(
+                tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
+            )
 
     def test_call_pil(self):
-        # Initialize image_processing
-        image_processing = self.image_processing_class(**self.image_processor_dict)
-        # create random PIL images
-        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
-        for images in image_inputs:
-            for image in images:
-                self.assertIsInstance(image, Image.Image)
-
-        # Test not batched input
-        encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
-        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
-        self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
-
-        # Test batched
-        encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
-        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
-        self.assertEqual(
-            tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
-        )
+        for image_processing_class in self.image_processor_list:
+            # Initialize image_processing
+            image_processing = image_processing_class(**self.image_processor_dict)
+            # create random PIL images
+            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
+            for images in image_inputs:
+                for image in images:
+                    self.assertIsInstance(image, Image.Image)
+
+            # Test not batched input
+            encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
+            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
+            self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
+
+            # Test batched
+            encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
+            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
+            self.assertEqual(
+                tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
+            )
 
     def test_call_pytorch(self):
-        # Initialize image_processing
-        image_processing = self.image_processing_class(**self.image_processor_dict)
-        # create random PyTorch tensors
-        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
-
-        for images in image_inputs:
-            for image in images:
-                self.assertIsInstance(image, torch.Tensor)
-
-        # Test not batched input
-        encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
-        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
-        self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
-
-        # Test batched
-        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
-        encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
-        self.assertEqual(
-            tuple(encoded_images.shape),
-            (self.image_processor_tester.batch_size, *expected_output_image_shape),
-        )
+        for image_processing_class in self.image_processor_list:
+            # Initialize image_processing
+            image_processing = image_processing_class(**self.image_processor_dict)
+            # create random PyTorch tensors
+            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
+
+            for images in image_inputs:
+                for image in images:
+                    self.assertIsInstance(image, torch.Tensor)
+
+            # Test not batched input
+            encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
+            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
+            self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
+
+            # Test batched
+            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
+            encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
+            self.assertEqual(
+                tuple(encoded_images.shape),
+                (self.image_processor_tester.batch_size, *expected_output_image_shape),
+            )
diff --git a/tests/models/mask2former/test_image_processing_mask2former.py b/tests/models/mask2former/test_image_processing_mask2former.py
index ae0fff89069054..98ffd906e5bf6e 100644
--- a/tests/models/mask2former/test_image_processing_mask2former.py
+++ b/tests/models/mask2former/test_image_processing_mask2former.py
@@ -98,6 +98,8 @@ def get_expected_values(self, image_inputs, batched=False):
         image = image_inputs[0]
         if isinstance(image, Image.Image):
             w, h = image.size
+        elif isinstance(image, np.ndarray):
+            h, w = image.shape[0], image.shape[1]
         else:
h, w = image.shape[1], image.shape[2] if w < h: diff --git a/tests/models/maskformer/test_image_processing_maskformer.py b/tests/models/maskformer/test_image_processing_maskformer.py index 5d30431f1f2aad..23e517a32626f7 100644 --- a/tests/models/maskformer/test_image_processing_maskformer.py +++ b/tests/models/maskformer/test_image_processing_maskformer.py @@ -98,6 +98,8 @@ def get_expected_values(self, image_inputs, batched=False): image = image_inputs[0] if isinstance(image, Image.Image): w, h = image.size + elif isinstance(image, np.ndarray): + h, w = image.shape[0], image.shape[1] else: h, w = image.shape[1], image.shape[2] if w < h: diff --git a/tests/models/oneformer/test_image_processing_oneformer.py b/tests/models/oneformer/test_image_processing_oneformer.py index e60cc31b30feee..a9dcdc2cfb9063 100644 --- a/tests/models/oneformer/test_image_processing_oneformer.py +++ b/tests/models/oneformer/test_image_processing_oneformer.py @@ -106,6 +106,8 @@ def get_expected_values(self, image_inputs, batched=False): image = image_inputs[0] if isinstance(image, Image.Image): w, h = image.size + elif isinstance(image, np.ndarray): + h, w = image.shape[0], image.shape[1] else: h, w = image.shape[1], image.shape[2] if w < h: diff --git a/tests/models/oneformer/test_processor_oneformer.py b/tests/models/oneformer/test_processor_oneformer.py index a8852ef30eb0ff..e86341ef0ee087 100644 --- a/tests/models/oneformer/test_processor_oneformer.py +++ b/tests/models/oneformer/test_processor_oneformer.py @@ -143,6 +143,8 @@ def get_expected_values(self, image_inputs, batched=False): image = image_inputs[0] if isinstance(image, Image.Image): w, h = image.size + elif isinstance(image, np.ndarray): + h, w = image.shape[0], image.shape[1] else: h, w = image.shape[1], image.shape[2] if w < h: diff --git a/tests/models/pix2struct/test_image_processing_pix2struct.py b/tests/models/pix2struct/test_image_processing_pix2struct.py index 09e1abd8068989..fee18814f0f755 100644 --- a/tests/models/pix2struct/test_image_processing_pix2struct.py +++ b/tests/models/pix2struct/test_image_processing_pix2struct.py @@ -232,7 +232,7 @@ def test_call_numpy_4_channels(self): for max_patch in self.image_processor_tester.max_patches: # Test not batched input encoded_images = image_processor( - image_inputs[0], return_tensors="pt", max_patches=max_patch, input_data_format="channels_first" + image_inputs[0], return_tensors="pt", max_patches=max_patch, input_data_format="channels_last" ).flattened_patches self.assertEqual( encoded_images.shape, @@ -241,7 +241,7 @@ def test_call_numpy_4_channels(self): # Test batched encoded_images = image_processor( - image_inputs, return_tensors="pt", max_patches=max_patch, input_data_format="channels_first" + image_inputs, return_tensors="pt", max_patches=max_patch, input_data_format="channels_last" ).flattened_patches self.assertEqual( encoded_images.shape, diff --git a/tests/models/swin2sr/test_image_processing_swin2sr.py b/tests/models/swin2sr/test_image_processing_swin2sr.py index 732a7e95412a88..86f7c8878ca5b1 100644 --- a/tests/models/swin2sr/test_image_processing_swin2sr.py +++ b/tests/models/swin2sr/test_image_processing_swin2sr.py @@ -72,6 +72,8 @@ def expected_output_image_shape(self, images): if isinstance(img, Image.Image): input_width, input_height = img.size + elif isinstance(img, np.ndarray): + input_height, input_width = img.shape[-3:-1] else: input_height, input_width = img.shape[-2:] @@ -160,7 +162,7 @@ def test_call_numpy_4_channels(self): # Test not batched input encoded_images = 
image_processing( - image_inputs[0], return_tensors="pt", input_data_format="channels_first" + image_inputs[0], return_tensors="pt", input_data_format="channels_last" ).pixel_values expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]]) self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) diff --git a/tests/models/video_llava/test_image_processing_video_llava.py b/tests/models/video_llava/test_image_processing_video_llava.py index 808001d2814def..4a5c2516267e13 100644 --- a/tests/models/video_llava/test_image_processing_video_llava.py +++ b/tests/models/video_llava/test_image_processing_video_llava.py @@ -285,7 +285,7 @@ def test_call_numpy_4_channels(self): encoded_images = image_processor( image_inputs[0], return_tensors="pt", - input_data_format="channels_first", + input_data_format="channels_last", image_mean=0, image_std=1, ).pixel_values_images @@ -296,7 +296,7 @@ def test_call_numpy_4_channels(self): encoded_images = image_processor( image_inputs, return_tensors="pt", - input_data_format="channels_first", + input_data_format="channels_last", image_mean=0, image_std=1, ).pixel_values_images diff --git a/tests/models/vilt/test_image_processing_vilt.py b/tests/models/vilt/test_image_processing_vilt.py index f68b2d2628ad7c..25026cb7d7a462 100644 --- a/tests/models/vilt/test_image_processing_vilt.py +++ b/tests/models/vilt/test_image_processing_vilt.py @@ -16,6 +16,8 @@ import unittest +import numpy as np + from transformers.testing_utils import require_torch, require_vision from transformers.utils import is_vision_available @@ -78,6 +80,8 @@ def get_expected_values(self, image_inputs, batched=False): image = image_inputs[0] if isinstance(image, Image.Image): w, h = image.size + elif isinstance(image, np.ndarray): + h, w = image.shape[0], image.shape[1] else: h, w = image.shape[1], image.shape[2] scale = size / min(w, h) diff --git a/tests/models/vit/test_image_processing_vit.py b/tests/models/vit/test_image_processing_vit.py index 1c376f55aa3e98..6d296654b8e8ff 100644 --- a/tests/models/vit/test_image_processing_vit.py +++ b/tests/models/vit/test_image_processing_vit.py @@ -17,7 +17,7 @@ import unittest from transformers.testing_utils import require_torch, require_vision -from transformers.utils import is_vision_available +from transformers.utils import is_torchvision_available, is_vision_available from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs @@ -25,6 +25,9 @@ if is_vision_available(): from transformers import ViTImageProcessor +if is_torchvision_available(): + from transformers import ViTImageProcessorFast + class ViTImageProcessingTester(unittest.TestCase): def __init__( @@ -82,6 +85,7 @@ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=F @require_vision class ViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image_processing_class = ViTImageProcessor if is_vision_available() else None + fast_image_processing_class = ViTImageProcessorFast if is_torchvision_available() else None def setUp(self): super().setUp() diff --git a/tests/models/yolos/test_image_processing_yolos.py b/tests/models/yolos/test_image_processing_yolos.py index a94cd8b883e834..1080edd2654ed3 100644 --- a/tests/models/yolos/test_image_processing_yolos.py +++ b/tests/models/yolos/test_image_processing_yolos.py @@ -18,6 +18,7 @@ import pathlib import unittest +import numpy as np from parameterized import parameterized from transformers.testing_utils 
import require_torch, require_vision, slow @@ -89,6 +90,8 @@ def get_expected_values(self, image_inputs, batched=False): image = image_inputs[0] if isinstance(image, Image.Image): width, height = image.size + elif isinstance(image, np.ndarray): + height, width = image.shape[0], image.shape[1] else: height, width = image.shape[1], image.shape[2] diff --git a/tests/test_image_processing_common.py b/tests/test_image_processing_common.py index e9b9467f580a2a..a36da4012b6d70 100644 --- a/tests/test_image_processing_common.py +++ b/tests/test_image_processing_common.py @@ -18,7 +18,9 @@ import os import pathlib import tempfile +import time +import numpy as np import requests from transformers import AutoImageProcessor, BatchFeature @@ -28,7 +30,6 @@ if is_torch_available(): - import numpy as np import torch if is_vision_available(): @@ -72,6 +73,10 @@ def prepare_image_inputs( if torchify: image_inputs = [torch.from_numpy(image) for image in image_inputs] + if numpify: + # Numpy images are typically in channels last format + image_inputs = [image.transpose(1, 2, 0) for image in image_inputs] + return image_inputs @@ -167,33 +172,28 @@ def test_slow_fast_equivalence(self): encoding_slow = image_processor_slow(dummy_image, return_tensors="pt") encoding_fast = image_processor_fast(dummy_image, return_tensors="pt") - self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-3)) + self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-2)) @require_vision @require_torch def test_fast_is_faster_than_slow(self): - import time - - def measure_time(self, image_processor, dummy_image): - start = time.time() - _ = image_processor(dummy_image, return_tensors="pt") - return time.time() - start - - dummy_image = Image.open( - requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw - ) - if not self.test_slow_image_processor or not self.test_fast_image_processor: self.skipTest("Skipping speed test") if self.image_processing_class is None or self.fast_image_processing_class is None: self.skipTest("Skipping speed test as one of the image processors is not defined") + def measure_time(image_processor, image): + start = time.time() + _ = image_processor(image, return_tensors="pt") + return time.time() - start + + dummy_images = torch.randint(0, 255, (4, 3, 224, 224), dtype=torch.uint8) image_processor_slow = self.image_processing_class(**self.image_processor_dict) - image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + image_processor_fast = self.fast_image_processing_class() - slow_time = self.measure_time(image_processor_slow, dummy_image) - fast_time = self.measure_time(image_processor_fast, dummy_image) + fast_time = measure_time(image_processor_fast, dummy_images) + slow_time = measure_time(image_processor_slow, dummy_images) self.assertLessEqual(fast_time, slow_time) @@ -238,6 +238,52 @@ def test_image_processor_save_load_with_autoimageprocessor(self): self.assertEqual(image_processor_second.to_dict(), image_processor_first.to_dict()) + def test_save_load_fast_slow(self): + "Test that we can load a fast image processor from a slow one and vice-versa." 
+ if self.image_processing_class is None or self.fast_image_processing_class is None: + self.skipTest("Skipping slow/fast save/load test as one of the image processors is not defined") + + image_processor_dict = self.image_processor_tester.prepare_image_processor_dict() + image_processor_slow_0 = self.image_processing_class(**image_processor_dict) + + # Load fast image processor from slow one + with tempfile.TemporaryDirectory() as tmpdirname: + image_processor_slow_0.save_pretrained(tmpdirname) + image_processor_fast_0 = self.fast_image_processing_class.from_pretrained(tmpdirname) + + image_processor_fast_1 = self.fast_image_processing_class(**image_processor_dict) + + # Load slow image processor from fast one + with tempfile.TemporaryDirectory() as tmpdirname: + image_processor_fast_1.save_pretrained(tmpdirname) + image_processor_slow_1 = self.image_processing_class.from_pretrained(tmpdirname) + + self.assertEqual(image_processor_slow_0.to_dict(), image_processor_slow_1.to_dict()) + self.assertEqual(image_processor_fast_0.to_dict(), image_processor_fast_1.to_dict()) + + def test_save_load_fast_slow_auto(self): + "Test that we can load a fast image processor from a slow one and vice-versa using AutoImageProcessor." + if self.image_processing_class is None or self.fast_image_processing_class is None: + self.skipTest("Skipping slow/fast save/load test as one of the image processors is not defined") + + image_processor_dict = self.image_processor_tester.prepare_image_processor_dict() + image_processor_slow_0 = self.image_processing_class(**image_processor_dict) + + # Load fast image processor from slow one + with tempfile.TemporaryDirectory() as tmpdirname: + image_processor_slow_0.save_pretrained(tmpdirname) + image_processor_fast_0 = AutoImageProcessor.from_pretrained(tmpdirname, use_fast=True) + + image_processor_fast_1 = self.fast_image_processing_class(**image_processor_dict) + + # Load slow image processor from fast one + with tempfile.TemporaryDirectory() as tmpdirname: + image_processor_fast_1.save_pretrained(tmpdirname) + image_processor_slow_1 = AutoImageProcessor.from_pretrained(tmpdirname, use_fast=False) + + self.assertEqual(image_processor_slow_0.to_dict(), image_processor_slow_1.to_dict()) + self.assertEqual(image_processor_fast_0.to_dict(), image_processor_fast_1.to_dict()) + def test_init_without_params(self): for image_processing_class in self.image_processor_list: image_processor = image_processing_class() @@ -358,7 +404,7 @@ def test_call_numpy_4_channels(self): encoded_images = image_processor( image_inputs[0], return_tensors="pt", - input_data_format="channels_first", + input_data_format="channels_last", image_mean=0, image_std=1, ).pixel_values @@ -369,7 +415,7 @@ def test_call_numpy_4_channels(self): encoded_images = image_processor( image_inputs, return_tensors="pt", - input_data_format="channels_first", + input_data_format="channels_last", image_mean=0, image_std=1, ).pixel_values
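
With this patch, `AutoImageProcessor.from_pretrained` treats `use_fast` as a three-state option: when it is left unset (`None`), the class recorded in the checkpoint is kept as-is; passing `True` or `False` explicitly switches to the fast or slow variant. A minimal sketch of the resulting behavior, assuming torchvision is installed (the checkpoint name is only illustrative):

    from transformers import AutoImageProcessor

    # use_fast unset: the processor class saved in the checkpoint is used, so a
    # repo saved with ViTImageProcessorFast loads the fast processor and one
    # saved with ViTImageProcessor loads the slow one.
    processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")

    # Explicitly request the torchvision-backed fast processor.
    processor_fast = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224", use_fast=True)

    # Explicitly request the slow, PIL/NumPy-based processor.
    processor_slow = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224", use_fast=False)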
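The `torch.vstack` -> `torch.stack` change in image_processing_vit_fast.py is a correctness fix for batching: given a list of channels-first image tensors, `vstack` concatenates along the existing first (channel) dimension instead of adding a batch dimension. A quick illustration:

    import torch

    # Two transformed images in channels-first (C, H, W) layout.
    images = [torch.rand(3, 224, 224), torch.rand(3, 224, 224)]

    torch.vstack(images).shape        # torch.Size([6, 224, 224]): batch and channel dims merged
    torch.stack(images, dim=0).shape  # torch.Size([2, 3, 224, 224]): proper (N, C, H, W) batch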
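test_slow_fast_equivalence relaxes the tolerance from 1e-3 to 1e-2, since the slow (PIL/NumPy) and fast (torchvision) pipelines resize with different backends and need not agree bit-for-bit. A sketch of the same check outside the test harness, reusing the COCO image URL from the test file and assuming default processor settings:

    import requests
    import torch
    from PIL import Image
    from transformers import ViTImageProcessor, ViTImageProcessorFast

    image = Image.open(
        requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
    )

    pixel_values_slow = ViTImageProcessor()(image, return_tensors="pt").pixel_values
    pixel_values_fast = ViTImageProcessorFast()(image, return_tensors="pt").pixel_values

    # The two resize backends differ slightly, hence the relaxed tolerance.
    assert torch.allclose(pixel_values_slow, pixel_values_fast, atol=1e-2)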