Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Jul 13, 2021
1 parent 787ddb8 commit 7121dfb
Showing 1 changed file with 29 additions and 29 deletions.
58 changes: 29 additions & 29 deletions tests/image/segmentation/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_from_folders(tmpdir):
]

num_classes: int = 2
img_size: Tuple[int, int] = (196, 196)
img_size: Tuple[int, int] = (128, 128)
create_random_data(images, targets, img_size, num_classes)

# instantiate the data module
Expand All @@ -110,20 +110,20 @@ def test_from_folders(tmpdir):
# check training data
data = next(iter(dm.train_dataloader()))
imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
assert imgs.shape == (2, 3, 196, 196)
assert labels.shape == (2, 196, 196)
assert imgs.shape == (2, 3, 128, 128)
assert labels.shape == (2, 128, 128)

# check val data
data = next(iter(dm.val_dataloader()))
imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
assert imgs.shape == (2, 3, 196, 196)
assert labels.shape == (2, 196, 196)
assert imgs.shape == (2, 3, 128, 128)
assert labels.shape == (2, 128, 128)

# check test data
data = next(iter(dm.test_dataloader()))
imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
assert imgs.shape == (2, 3, 196, 196)
assert labels.shape == (2, 196, 196)
assert imgs.shape == (2, 3, 128, 128)
assert labels.shape == (2, 128, 128)

@staticmethod
def test_from_folders_warning(tmpdir):
Expand All @@ -145,7 +145,7 @@ def test_from_folders_warning(tmpdir):
]

num_classes: int = 2
img_size: Tuple[int, int] = (196, 196)
img_size: Tuple[int, int] = (128, 128)
create_random_data(images, targets, img_size, num_classes)

# instantiate the data module
Expand All @@ -164,8 +164,8 @@ def test_from_folders_warning(tmpdir):
# check training data
data = next(iter(dm.train_dataloader()))
imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
assert imgs.shape == (1, 3, 196, 196)
assert labels.shape == (1, 196, 196)
assert imgs.shape == (1, 3, 128, 128)
assert labels.shape == (1, 128, 128)

@staticmethod
def test_from_files(tmpdir):
Expand All @@ -186,7 +186,7 @@ def test_from_files(tmpdir):
]

num_classes: int = 2
img_size: Tuple[int, int] = (196, 196)
img_size: Tuple[int, int] = (128, 128)
create_random_data(images, targets, img_size, num_classes)

# instantiate the data module
Expand All @@ -210,20 +210,20 @@ def test_from_files(tmpdir):
# check training data
data = next(iter(dm.train_dataloader()))
imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
assert imgs.shape == (2, 3, 196, 196)
assert labels.shape == (2, 196, 196)
assert imgs.shape == (2, 3, 128, 128)
assert labels.shape == (2, 128, 128)

# check val data
data = next(iter(dm.val_dataloader()))
imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
assert imgs.shape == (2, 3, 196, 196)
assert labels.shape == (2, 196, 196)
assert imgs.shape == (2, 3, 128, 128)
assert labels.shape == (2, 128, 128)

# check test data
data = next(iter(dm.test_dataloader()))
imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
assert imgs.shape == (2, 3, 196, 196)
assert labels.shape == (2, 196, 196)
assert imgs.shape == (2, 3, 128, 128)
assert labels.shape == (2, 128, 128)

@staticmethod
def test_from_files_warning(tmpdir):
Expand All @@ -244,7 +244,7 @@ def test_from_files_warning(tmpdir):
]

num_classes: int = 2
img_size: Tuple[int, int] = (196, 196)
img_size: Tuple[int, int] = (128, 128)
create_random_data(images, targets, img_size, num_classes)

# instantiate the data module
Expand Down Expand Up @@ -272,7 +272,7 @@ def test_from_fiftyone(tmpdir):
]

num_classes: int = 2
img_size: Tuple[int, int] = (196, 196)
img_size: Tuple[int, int] = (128, 128)

for img_file in images:
_rand_image(img_size).save(img_file)
Expand Down Expand Up @@ -307,25 +307,25 @@ def test_from_fiftyone(tmpdir):
# check training data
data = next(iter(dm.train_dataloader()))
imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
assert imgs.shape == (2, 3, 196, 196)
assert labels.shape == (2, 196, 196)
assert imgs.shape == (2, 3, 128, 128)
assert labels.shape == (2, 128, 128)

# check val data
data = next(iter(dm.val_dataloader()))
imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
assert imgs.shape == (2, 3, 196, 196)
assert labels.shape == (2, 196, 196)
assert imgs.shape == (2, 3, 128, 128)
assert labels.shape == (2, 128, 128)

# check test data
data = next(iter(dm.test_dataloader()))
imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
assert imgs.shape == (2, 3, 196, 196)
assert labels.shape == (2, 196, 196)
assert imgs.shape == (2, 3, 128, 128)
assert labels.shape == (2, 128, 128)

# check predict data
data = next(iter(dm.predict_dataloader()))
imgs = data[DefaultDataKeys.INPUT]
assert imgs.shape == (2, 3, 196, 196)
assert imgs.shape == (2, 3, 128, 128)

@staticmethod
def test_map_labels(tmpdir):
Expand All @@ -351,7 +351,7 @@ def test_map_labels(tmpdir):
}

num_classes: int = len(labels_map.keys())
img_size: Tuple[int, int] = (256, 256)
img_size: Tuple[int, int] = (128, 128)
create_random_data(images, targets, img_size, num_classes)

# instantiate the data module
Expand Down Expand Up @@ -379,8 +379,8 @@ def test_map_labels(tmpdir):
# check training data
data = next(iter(dm.train_dataloader()))
imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
assert imgs.shape == (2, 3, 256, 256)
assert labels.shape == (2, 256, 256)
assert imgs.shape == (2, 3, 128, 128)
assert labels.shape == (2, 128, 128)
assert labels.min().item() == 0
assert labels.max().item() == 1
assert labels.dtype == torch.int64
Expand Down

0 comments on commit 7121dfb

Please sign in to comment.