diff --git a/tests/task/vision/test_image_classification.py b/tests/task/vision/test_image_classification.py index f01c8246..62d2193d 100644 --- a/tests/task/vision/test_image_classification.py +++ b/tests/task/vision/test_image_classification.py @@ -17,7 +17,7 @@ def test_smoke_train(hf_cache_path): feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path="nateraw/tiny-vit-random") dm = ImageClassificationDataModule( - cfg=ImageClassificationDataConfig(batch_size=2, dataset_name="beans"), + cfg=ImageClassificationDataConfig(batch_size=1, dataset_name="beans"), feature_extractor=feature_extractor, ) model = ImageClassificationTransformer(