Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 29, 2021
1 parent ccf25c3 commit 0f7df42
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions tests/image/classification/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,15 @@ def test_from_filepaths_smoke(tmpdir):
def test_from_data_frame_smoke(tmpdir):
tmpdir = Path(tmpdir)

df = pd.DataFrame({"file": ["train.png", "valid.png", "test.png"], "split": ["train", "valid", "test"],
"target": [0, 1, 1]})
df = pd.DataFrame(
{"file": ["train.png", "valid.png", "test.png"], "split": ["train", "valid", "test"], "target": [0, 1, 1]}
)

[_rand_image().save(tmpdir / row.file) for i, row in df.iterrows()]

img_data = ImageClassificationData.from_data_frame(
"file", "target",
"file",
"target",
train_images_root=str(tmpdir),
val_images_root=str(tmpdir),
test_images_root=str(tmpdir),
Expand All @@ -104,7 +106,8 @@ def test_from_data_frame_smoke(tmpdir):
test_data_frame=df[df.split == "test"],
predict_images_root=str(tmpdir),
batch_size=1,
predict_data_frame=df)
predict_data_frame=df,
)

assert img_data.train_dataloader() is not None
assert img_data.val_dataloader() is not None
Expand Down

0 comments on commit 0f7df42

Please sign in to comment.