From 9e7ede97660ba4fa23ff577990debf2063cae56a Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 6 Jan 2022 12:34:53 +0000 Subject: [PATCH] Fixes --- flash/image/segmentation/data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flash/image/segmentation/data.py b/flash/image/segmentation/data.py index 5e1b4e4df5..9324620b40 100644 --- a/flash/image/segmentation/data.py +++ b/flash/image/segmentation/data.py @@ -117,7 +117,7 @@ def from_files( >>> from PIL import Image >>> rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")) - >>> rand_mask= Image.fromarray(np.random.randint(0, 10, (64, 64, 1), dtype="uint8")) + >>> rand_mask= Image.fromarray(np.random.randint(0, 10, (64, 64), dtype="uint8")) >>> _ = [rand_image.save(f"image_{i}.png") for i in range(1, 4)] >>> _ = [rand_mask.save(f"mask_{i}.png") for i in range(1, 4)] >>> _ = [rand_image.save(f"predict_image_{i}.png") for i in range(1, 4)] @@ -262,7 +262,7 @@ def from_folders( >>> import os >>> from PIL import Image >>> rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")) - >>> rand_mask = Image.fromarray(np.random.randint(0, 10, (64, 64, 1), dtype="uint8")) + >>> rand_mask = Image.fromarray(np.random.randint(0, 10, (64, 64), dtype="uint8")) >>> os.makedirs("train_images", exist_ok=True) >>> os.makedirs("train_masks", exist_ok=True) >>> os.makedirs("predict_folder", exist_ok=True)