Skip to content

Commit

Permalink
fix sample with custom dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
pedropesserl committed Oct 22, 2024
1 parent 88abd67 commit a47786b
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
21 changes: 11 additions & 10 deletions guided_diffusion/custom_dataset_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ def __init__(self, args, data_path , transform = None, mode = 'Training',plane =
images = sorted(glob(os.path.join(path, "images/*.png")))
masks = sorted(glob(os.path.join(path, "masks/*.png")))

self.name_list = images[:2]
self.label_list = masks[:2]
self.name_list = images
self.label_list = masks
self.data_path = path
self.mode = mode

Expand All @@ -44,18 +44,19 @@ def __getitem__(self, index):
img = Image.open(img_path).convert('RGB')
mask = Image.open(msk_path).convert('L')

if self.mode == 'Training':
label = 0 if self.label_list[index] == 'benign' else 1
else:
label = int(self.label_list[index])
# if self.mode == 'Training':
# label = 0 if self.label_list[index] == 'benign' else 1
# else:
# label = int(self.label_list[index])

if self.transform:
state = torch.get_rng_state()
img = self.transform(img)
torch.set_rng_state(state)
mask = self.transform(mask)

if self.mode == 'Training':
return (img, mask, name)
else:
return (img, mask, name)
return (img, mask, name)
# if self.mode == 'Training':
# return (img, mask, name)
# else:
# return (img, mask, name)
8 changes: 8 additions & 0 deletions scripts/segmentation_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from guided_diffusion import dist_util, logger
from guided_diffusion.bratsloader import BRATSDataset, BRATSDataset3D
from guided_diffusion.isicloader import ISICDataset
from guided_diffusion.custom_dataset_loader import CustomDataset
import torchvision.utils as vutils
from guided_diffusion.utils import staple
from guided_diffusion.script_util import (
Expand Down Expand Up @@ -58,6 +59,13 @@ def main():

ds = BRATSDataset3D(args.data_dir,transform_test)
args.in_ch = 5
else:
tran_list = [transforms.Resize((args.image_size,args.image_size)), transforms.ToTensor()]
transform_test = transforms.Compose(tran_list)

ds = CustomDataset(args, args.data_dir, transform_test, mode = 'Test')
args.in_ch = 4

datal = th.utils.data.DataLoader(
ds,
batch_size=args.batch_size,
Expand Down

0 comments on commit a47786b

Please sign in to comment.