Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SegGPT] Fix loss calculation #30421

Merged
merged 6 commits into from
Apr 24, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 29 additions & 16 deletions src/transformers/models/seggpt/modeling_seggpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,11 +753,15 @@ def forward(
bool_masked_pos: Optional[torch.BoolTensor] = None,
feature_ensemble: Optional[bool] = None,
embedding_type: Optional[str] = None,
labels: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SegGptEncoderOutput]:
r"""
labels (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, `optional`):
Ground truth mask for input images.

Returns:

Examples:
Expand Down Expand Up @@ -799,10 +803,21 @@ def forward(

# Prepare inputs
pixel_values = torch.cat((prompt_pixel_values, pixel_values), dim=2)
prompt_pixel_values = torch.cat((prompt_masks, prompt_masks), dim=2)
prompt_pixel_values = (
torch.cat((prompt_masks, prompt_masks), dim=2)
if labels is None
else torch.cat((prompt_pixel_values, labels), dim=2)
)

if bool_masked_pos is None and labels is not None:
logger.warning_once(
"Labels were provided, but bool_masked_pos is not. It will be set to default value. If you're training the model, make sure to provide a bool_masked_pos."
EduardoPach marked this conversation as resolved.
Show resolved Hide resolved
)

# We concat on height axis so SegGPT can handle it as a single image, hence we need to mask the portion
# of the prompt pixels that will be destined to the prediction as they don't add any information.
# of the mask prompt pixels that will be destined to the prediction as they don't add any information.
# This is only the case for inference. In training, the concatenation of prompt mask and label is masked
# and reconstructed together (In-Context Painting).
if bool_masked_pos is None:
num_patches = self.embeddings.patch_embeddings.num_patches
bool_masked_pos = torch.zeros(num_patches, dtype=torch.bool).to(pixel_values.device)
Expand Down Expand Up @@ -840,7 +855,9 @@ def unpatchify(tensor: torch.Tensor, patch_height: int, patch_width: int) -> tor
batch_size = tensor.shape[0]
patch_size = int((tensor.shape[-1] / 3) ** 0.5)
if patch_height * patch_width != tensor.shape[1]:
raise ValueError(f"Number of patches {tensor.shape[1]} does not match patch height and width.")
raise ValueError(
f"Number of patches {tensor.shape[1]} does not match patch height ({patch_height}) and width ({patch_width})."
)

tensor = tensor.reshape(shape=(batch_size, patch_height, patch_width, patch_size, patch_size, 3))
tensor = tensor.permute(0, 5, 1, 3, 2, 4)
Expand All @@ -857,20 +874,16 @@ def __init__(self, config):

def forward(
self,
pixel_values: torch.FloatTensor,
prompt_pixel_values: torch.FloatTensor,
prompt_masks: torch.FloatTensor,
pred_masks: torch.FloatTensor,
labels: torch.FloatTensor,
bool_masked_pos: torch.BoolTensor,
):
"""Computes the L1 loss between the predicted masks and the ground truth masks.

Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, 2*height, width)`):
Concatenated pixel values from prompt and input images.

prompt_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, 2*height, width)`):
Concatenated pixel values from mask prompt.
prompt_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values from mask prompt.

pred_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, 2*height, width)`):
Predicted masks.
Expand All @@ -884,12 +897,12 @@ def forward(
Returns:
`torch.FloatTensor`: The mean L1 loss between the predicted masks and the ground truth masks.
"""
ground_truth = torch.cat((prompt_masks, labels), dim=2)

mask = bool_masked_pos[:, :, None].repeat(1, 1, self.patch_size**2 * 3)
mask = unpatchify(mask, pixel_values.shape[1] // self.patch_size, pixel_values.shape[2] // self.patch_size)
# Changing dummy mask in prompt_pixel_values to labels values
prompt_pixel_values = prompt_pixel_values.clone()
prompt_pixel_values[:, :, prompt_pixel_values.shape[2] // 2 :, :] = labels
loss = F.smooth_l1_loss(pred_masks, prompt_pixel_values, reduction="none", beta=self.beta)
mask = unpatchify(mask, ground_truth.shape[2] // self.patch_size, ground_truth.shape[3] // self.patch_size)

loss = F.smooth_l1_loss(pred_masks, ground_truth, reduction="none", beta=self.beta)
loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches

return loss
Expand Down Expand Up @@ -988,7 +1001,7 @@ def forward(
loss = None
if labels is not None:
loss_fn = SegGptLoss(self.config)
loss = loss_fn(pixel_values, prompt_pixel_values, pred_masks, labels, bool_masked_pos)
loss = loss_fn(prompt_masks, pred_masks, labels, bool_masked_pos)

if not return_dict:
output = (pred_masks,)
Expand Down