Skip to content

Commit

Permalink
Modify mask decoder for two-class mask output.
Browse files Browse the repository at this point in the history
  • Loading branch information
LJQCN101 authored Mar 6, 2024
1 parent bb57efb commit 04434fd
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions models/sam/modeling/mask_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ def forward(
)

# Select the correct mask or masks for output
if multimask_output:
mask_slice = slice(1, None)
if multimask_output: # for REFUGE dataset
mask_slice = slice(0, 2)
else:
mask_slice = slice(0, 1)
masks = masks[:, mask_slice, :, :]
Expand Down

0 comments on commit 04434fd

Please sign in to comment.