Skip to content

Commit

Permalink
Merge pull request #97 from LJQCN101/main
Browse files Browse the repository at this point in the history
Support multi-class segmentation for REFUGE dataset
  • Loading branch information
WuJunde authored Mar 6, 2024
2 parents feb9c48 + 04434fd commit 1888b84
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 27 deletions.
1 change: 1 addition & 0 deletions cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def parse_args():
parser.add_argument('-roi_size', type=int, default=96 , help='resolution of roi')
parser.add_argument('-evl_chunk', type=int, default=None , help='evaluation chunk')
parser.add_argument('-mid_dim', type=int, default=None , help='middle dim of adapter or the rank of lora matrix')
parser.add_argument('-multimask_output', type=bool, default=False , help='multi mask output for multi-class segmentation, set True for REFUGE dataset.')
parser.add_argument(
'-data_path',
type=str,
Expand Down
22 changes: 10 additions & 12 deletions dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __getitem__(self, index):


class REFUGE(Dataset):
def __init__(self, args, data_path , transform = None, transform_msk = None, mode = 'Training',prompt = 'click', plane = False):
def __init__(self, args, data_path , transform = None, transform_msk = None, mode = 'Training',prompt = 'none', plane = False):
self.data_path = data_path
self.subfolders = [f.path for f in os.scandir(os.path.join(data_path, mode + '-400')) if f.is_dir()]
self.mode = mode
Expand Down Expand Up @@ -132,9 +132,12 @@ def __getitem__(self, index):

# first click is the target agreement among most raters
if self.prompt == 'click':
point_label, pt_cup = random_click(np.array(np.mean(np.stack(multi_rater_cup_np), axis=0)) / 255, point_label)
point_label, pt = random_click(np.array(np.mean(np.stack(multi_rater_cup_np), axis=0)) / 255, point_label)
point_label, pt_disc = random_click(np.array(np.mean(np.stack(multi_rater_disc_np), axis=0)) / 255, point_label)

else:
# through experiment, you can get rid of prompts and it barely hurts the accuracy
pt = np.array([0, 0], dtype=np.int32)

if self.transform:
state = torch.get_rng_state()
img = self.transform(img)
Expand All @@ -147,20 +150,15 @@ def __getitem__(self, index):
multi_rater_disc = torch.stack(multi_rater_disc, dim=0)
mask_disc = F.interpolate(multi_rater_disc, size=(self.mask_size, self.mask_size), mode='bilinear', align_corners=False).mean(dim=0)
torch.set_rng_state(state)

mask = torch.concat([mask_cup, mask_disc], dim=0)

image_meta_dict = {'filename_or_obj':name}
return {
'image':img,
'multi_rater': multi_rater_cup,
'multi_rater_disc': multi_rater_disc,
'mask_cup': mask_cup,
'mask_disc': mask_disc,
'label': mask_cup,
# 'label': mask_disc,
'label': mask,
'p_label':point_label,
'pt_cup':pt_cup,
'pt_disc':pt_disc,
'pt':pt_cup,
'pt':pt,
'image_meta_dict':image_meta_dict,
}

Expand Down
21 changes: 18 additions & 3 deletions function.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,15 +161,22 @@ def train_sam(args, net: nn.Module, optimizer, train_loader,
labels=labels_torch,
)

if args.net == 'sam' or args.net == 'mobile_sam':
if args.net == 'sam':
pred, _ = net.mask_decoder(
image_embeddings=imge,
image_pe=net.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=se,
dense_prompt_embeddings=de,
multimask_output=args.multimask_output,
)
elif args.net == 'mobile_sam':
pred, _ = net.mask_decoder(
image_embeddings=imge,
image_pe=net.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=se,
dense_prompt_embeddings=de,
multimask_output=False,
)

elif args.net == "efficient_sam":
se = se.view(
se.shape[0],
Expand Down Expand Up @@ -310,7 +317,15 @@ def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True):
labels=labels_torch,
)

if args.net == 'sam' or args.net == 'mobile_sam':
if args.net == 'sam':
pred, _ = net.mask_decoder(
image_embeddings=imge,
image_pe=net.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=se,
dense_prompt_embeddings=de,
multimask_output=args.multimask_output,
)
elif args.net == 'mobile_sam':
pred, _ = net.mask_decoder(
image_embeddings=imge,
image_pe=net.prompt_encoder.get_dense_pe(),
Expand Down
4 changes: 2 additions & 2 deletions models/sam/modeling/mask_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ def forward(
)

# Select the correct mask or masks for output
if multimask_output:
mask_slice = slice(1, None)
if multimask_output: # for REFUGE dataset
mask_slice = slice(0, 2)
else:
mask_slice = slice(0, 1)
masks = masks[:, mask_slice, :, :]
Expand Down
18 changes: 13 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,13 @@

for epoch in range(settings.EPOCH):
if epoch and epoch < 5:
tol, (eiou, edice) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {epoch}.')

if args.dataset != 'REFUGE':
tol, (eiou, edice) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {epoch}.')
else:
tol, (eiou_cup, eiou_disc, edice_cup, edice_disc) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
logger.info(f'Total score: {tol}, IOU_CUP: {eiou_cup}, IOU_DISC: {eiou_disc}, DICE_CUP: {edice_cup}, DICE_DISC: {edice_disc} || @ epoch {epoch}.')

net.train()
time_start = time.time()
loss = function.train_sam(args, net, optimizer, nice_train_loader, epoch, writer, vis = args.vis)
Expand All @@ -165,8 +169,12 @@

net.eval()
if epoch and epoch % args.val_freq == 0 or epoch == settings.EPOCH-1:
tol, (eiou, edice) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {epoch}.')
if args.dataset != 'REFUGE':
tol, (eiou, edice) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {epoch}.')
else:
tol, (eiou_cup, eiou_disc, edice_cup, edice_disc) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
logger.info(f'Total score: {tol}, IOU_CUP: {eiou_cup}, IOU_DISC: {eiou_disc}, DICE_CUP: {edice_cup}, DICE_DISC: {edice_disc} || @ epoch {epoch}.')

if args.distributed != 'none':
sd = net.module.state_dict()
Expand Down
5 changes: 2 additions & 3 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,12 +963,11 @@ def vis_image(imgs, pred_masks, gt_masks, save_path, reverse = False, points = N
if reverse == True:
pred_masks = 1 - pred_masks
gt_masks = 1 - gt_masks
if c == 2:
if c == 2: # for REFUGE multi mask output
pred_disc, pred_cup = pred_masks[:,0,:,:].unsqueeze(1).expand(b,3,h,w), pred_masks[:,1,:,:].unsqueeze(1).expand(b,3,h,w)
gt_disc, gt_cup = gt_masks[:,0,:,:].unsqueeze(1).expand(b,3,h,w), gt_masks[:,1,:,:].unsqueeze(1).expand(b,3,h,w)
tup = (imgs[:row_num,:,:,:],pred_disc[:row_num,:,:,:], pred_cup[:row_num,:,:,:], gt_disc[:row_num,:,:,:], gt_cup[:row_num,:,:,:])
# compose = torch.cat((imgs[:row_num,:,:,:],pred_disc[:row_num,:,:,:], pred_cup[:row_num,:,:,:], gt_disc[:row_num,:,:,:], gt_cup[:row_num,:,:,:]),0)
compose = torch.cat((pred_disc[:row_num,:,:,:], pred_cup[:row_num,:,:,:], gt_disc[:row_num,:,:,:], gt_cup[:row_num,:,:,:]),0)
compose = torch.cat(tup, 0)
vutils.save_image(compose, fp = save_path, nrow = row_num, padding = 10)
else:
imgs = torchvision.transforms.Resize((h,w))(imgs)
Expand Down
20 changes: 18 additions & 2 deletions val.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,28 @@
elif args.dataset == 'decathlon':
nice_train_loader, nice_test_loader, transform_train, transform_val, train_list, val_list =get_decath_loader(args)

elif args.dataset == 'REFUGE':
'''REFUGE data'''
refuge_train_dataset = REFUGE(args, args.data_path, transform = transform_train, transform_msk= transform_train_seg, mode = 'Training')
refuge_test_dataset = REFUGE(args, args.data_path, transform = transform_test, transform_msk= transform_test_seg, mode = 'Test')

nice_train_loader = DataLoader(refuge_train_dataset, batch_size=args.b, shuffle=True, num_workers=8, pin_memory=True)
nice_test_loader = DataLoader(refuge_test_dataset, batch_size=args.b, shuffle=False, num_workers=8, pin_memory=True)
'''end'''


'''begin evaluation'''
best_acc = 0.0
best_tol = 1e4

if args.mod == 'sam_adpt':
net.eval()
tol, (eiou, edice) = function.validation_sam(args, nice_test_loader, start_epoch, net)
logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {start_epoch}.')

if args.dataset != 'REFUGE':
tol, (eiou, edice) = function.validation_sam(args, nice_test_loader, start_epoch, net)
logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {start_epoch}.')
else:
tol, (eiou_cup, eiou_disc, edice_cup, edice_disc) = function.validation_sam(args, nice_test_loader, start_epoch, net)
logger.info(f'Total score: {tol}, IOU_CUP: {eiou_cup}, IOU_DISC: {eiou_disc}, DICE_CUP: {edice_cup}, DICE_DISC: {edice_disc} || @ epoch {start_epoch}.')


0 comments on commit 1888b84

Please sign in to comment.