Skip to content

Commit

Permalink
Merge pull request #87 from jiayuanz3/main
Browse files Browse the repository at this point in the history
add LIDC datasetloader, update environment, modify random_click func
  • Loading branch information
WuJunde authored Jan 22, 2024
2 parents 57ee54b + 0241583 commit df82a93
Show file tree
Hide file tree
Showing 4 changed files with 256 additions and 144 deletions.
95 changes: 85 additions & 10 deletions dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def __getitem__(self, index):
# else:
# inout = 1
# point_label = 1
inout = 1
point_label = 1

"""Get the images"""
Expand All @@ -69,7 +68,7 @@ def __getitem__(self, index):
mask = mask.resize(newsize)

if self.prompt == 'click':
pt = random_click(np.array(mask) / 255, point_label, inout)
point_label, pt = random_click(np.array(mask) / 255, point_label)

if self.transform:
state = torch.get_rng_state()
Expand Down Expand Up @@ -110,7 +109,6 @@ def __len__(self):
return len(self.subfolders)

def __getitem__(self, index):
inout = 1
point_label = 1

"""Get the images"""
Expand All @@ -132,10 +130,10 @@ def __getitem__(self, index):
multi_rater_cup_np = [np.array(single_rater.resize(newsize)) for single_rater in multi_rater_cup]
multi_rater_disc_np = [np.array(single_rater.resize(newsize)) for single_rater in multi_rater_disc]

# first click is the target agreement among all raters
# first click is the target agreement among most raters
if self.prompt == 'click':
pt_cup = random_click(np.array(np.mean(np.stack(multi_rater_cup_np), axis=0)) / 255, point_label, inout)
pt_disc = random_click(np.array(np.mean(np.stack(multi_rater_disc_np), axis=0)) / 255, point_label, inout)
point_label, pt_cup = random_click(np.array(np.mean(np.stack(multi_rater_cup_np), axis=0)) / 255, point_label)
point_label, pt_disc = random_click(np.array(np.mean(np.stack(multi_rater_disc_np), axis=0)) / 255, point_label)

if self.transform:
state = torch.get_rng_state()
Expand All @@ -153,16 +151,93 @@ def __getitem__(self, index):
image_meta_dict = {'filename_or_obj':name}
return {
'image':img,
'multi_rater_cup': multi_rater_cup,
'multi_rater': multi_rater_cup,
'multi_rater_disc': multi_rater_disc,
'mask_cup': mask_cup,
'mask_disc': mask_disc,
'label': mask_disc,
'label': mask_cup,
'p_label':point_label,
'pt_cup':pt_cup,
'pt_disc':pt_disc,
'pt':pt_disc,
'selected_rater': torch.tensor(np.arange(7)),
'pt':pt_cup,
'image_meta_dict':image_meta_dict,
}


class LIDC(Dataset):
    """Multi-rater dataset built from the ``.pickle`` files found in a directory.

    Every pickle file is expected to hold a dict mapping a sample name to a
    record with the keys ``'image'``, ``'masks'`` and ``'series_uid'``. The
    records of all files are merged (later files override earlier ones on a
    name clash) and stored in the parallel lists ``names``, ``images``,
    ``labels`` and ``series_uid``.

    Args:
        data_path: Directory that contains the ``.pickle`` files. A trailing
            path separator is optional.
        transform: Stored on the instance; not applied by ``__getitem__``.
        transform_msk: Stored on the instance; not applied by ``__getitem__``.
        prompt: Prompt type. Only ``'click'`` is handled by ``__getitem__``.
    """

    def __init__(self, data_path, transform=None, transform_msk = None, prompt = 'click'):
        self.prompt = prompt
        self.transform = transform
        self.transform_msk = transform_msk

        # Per-instance storage. These used to be class-level lists, which made
        # every new instance (e.g. a train and a validation dataset) append to
        # the same shared lists and silently duplicate samples.
        self.names = []
        self.images = []
        self.labels = []
        self.series_uid = []

        # Files are read in chunks of at most 2**31 - 1 bytes
        # (presumably a workaround for platforms that fail on larger single reads).
        max_bytes = 2**31 - 1
        data = {}
        # Sorted so that the sample order does not depend on the arbitrary
        # order in which os.listdir returns directory entries.
        for file in sorted(os.listdir(data_path)):
            filename = os.fsdecode(file)
            if '.pickle' not in filename:
                continue
            # os.path.join works whether or not data_path ends with a separator.
            file_path = os.path.join(data_path, filename)
            input_size = os.path.getsize(file_path)
            bytes_in = bytearray(0)
            with open(file_path, 'rb') as f_in:
                for _ in range(0, input_size, max_bytes):
                    bytes_in += f_in.read(max_bytes)
            # SECURITY: unpickling executes arbitrary code embedded in the file;
            # only point data_path at trusted data.
            data.update(pickle.loads(bytes_in))

        for key, value in data.items():
            self.names.append(key)
            self.images.append(value['image'].astype(float))
            self.labels.append(value['masks'])
            self.series_uid.append(value['series_uid'])

        assert (len(self.images) == len(self.labels) == len(self.series_uid))

        # Images and masks are required to be normalised to [0, 1] already.
        for img in self.images:
            assert np.max(img) <= 1 and np.min(img) >= 0
        for label in self.labels:
            assert np.max(label) <= 1 and np.min(label) >= 0

        # Drop the merged dict; the per-sample lists keep what is needed.
        # (No `del new_data`: that name does not exist when the directory
        # contains no pickle file, which used to raise UnboundLocalError.)
        del data

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.images)

    def __getitem__(self, index):
        """Return the sample at ``index`` as a dict of tensors and metadata.

        Keys of the returned dict:
            ``'image'``: float32 tensor, the image with a leading axis added
                and repeated three times along it (``(3, H, W)`` for a 2-D
                image).
            ``'multi_rater'``: float32 tensor of all rater masks, stacked on
                dim 0 with a singleton dim inserted at position 1.
            ``'label'``: mean of ``'multi_rater'`` over the rater dimension.
            ``'p_label'``: point label returned by ``random_click``.
            ``'pt'``: click position returned by ``random_click``.
            ``'image_meta_dict'``: ``{'filename_or_obj': <sample name>}``.

        Only ``prompt == 'click'`` assigns ``pt``; with any other prompt the
        return statement fails because ``pt`` is unassigned.
        """
        point_label = 1

        # Add a leading channel axis to the stored image.
        img = np.expand_dims(self.images[index], axis=0)
        name = self.names[index]
        multi_rater = self.labels[index]

        # The click is drawn from the mean of all rater masks
        # (target where raters agree most, otherwise background agreement).
        if self.prompt == 'click':
            # NOTE(review): the masks are asserted to lie in [0, 1] in __init__,
            # yet the mean mask is still divided by 255 here, as the image-file
            # based loaders do — confirm against random_click that this scaling
            # is intended.
            point_label, pt = random_click(np.array(np.mean(np.stack(multi_rater), axis=0)) / 255, point_label)

        # Convert image (ensure three channels) and multi-rater labels to torch tensors
        img = torch.from_numpy(img).type(torch.float32)
        img = img.repeat(3, 1, 1)
        multi_rater = [torch.from_numpy(single_rater).type(torch.float32) for single_rater in multi_rater]

        multi_rater = torch.stack(multi_rater, dim=0)
        multi_rater = multi_rater.unsqueeze(1)
        mask = multi_rater.mean(dim=0)  # average over raters

        image_meta_dict = {'filename_or_obj':name}
        return {
            'image':img,
            'multi_rater': multi_rater,
            'label': mask,
            'p_label':point_label,
            'pt':pt,
            'image_meta_dict':image_meta_dict,
        }

Loading

0 comments on commit df82a93

Please sign in to comment.