
Commit

Fixed main train issues
EduardoPach committed Apr 23, 2024
1 parent 77b59dc commit 5b001b8
Showing 1 changed file with 29 additions and 16 deletions.
45 changes: 29 additions & 16 deletions src/transformers/models/seggpt/modeling_seggpt.py
@@ -753,11 +753,15 @@ def forward(
         bool_masked_pos: Optional[torch.BoolTensor] = None,
         feature_ensemble: Optional[bool] = None,
         embedding_type: Optional[str] = None,
+        labels: Optional[torch.FloatTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, SegGptEncoderOutput]:
         r"""
+        labels (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
+            Ground truth mask for input images.
 
         Returns:
 
         Examples:
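For context, a minimal sketch of the tensor shapes a training call would pass to this forward; the 448x448 input resolution is an assumption based on SegGPT's default image processor, not something this diff states:

import torch

# Hypothetical shapes: batch of 2, RGB, 448x448 inputs (assumed default).
batch_size, num_channels, height, width = 2, 3, 448, 448

pixel_values = torch.randn(batch_size, num_channels, height, width)
prompt_pixel_values = torch.randn(batch_size, num_channels, height, width)
prompt_masks = torch.randn(batch_size, num_channels, height, width)
labels = torch.randn(batch_size, num_channels, height, width)  # ground truth mask, same shape as the inputs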
@@ -799,10 +803,21 @@ def forward(

         # Prepare inputs
         pixel_values = torch.cat((prompt_pixel_values, pixel_values), dim=2)
-        prompt_pixel_values = torch.cat((prompt_masks, prompt_masks), dim=2)
+        prompt_pixel_values = (
+            torch.cat((prompt_masks, prompt_masks), dim=2)
+            if labels is None
+            else torch.cat((prompt_pixel_values, labels), dim=2)
+        )
 
+        if bool_masked_pos is None and labels is not None:
+            logger.warning_once(
+                "Labels were provided, but bool_masked_pos was not. It will be set to its default value. If you're training the model, make sure to provide a bool_masked_pos."
+            )
+
         # We concat on the height axis so SegGPT can handle the pair as a single image, hence we need to mask the portion
-        # of the prompt pixels that is destined for the prediction, as it adds no information.
+        # of the mask prompt pixels that is destined for the prediction, as it adds no information.
+        # This is only the case for inference. In training, the concatenation of prompt mask and label is masked
+        # and reconstructed together (In-Context Painting).
         if bool_masked_pos is None:
             num_patches = self.embeddings.patch_embeddings.num_patches
             bool_masked_pos = torch.zeros(num_patches, dtype=torch.bool).to(pixel_values.device)
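To make the concatenation above concrete, here is a hedged sketch of the in-context painting layout. It assumes the 448x448 inputs and 16x16 patches from before; the ground-truth construction mirrors the loss change later in this diff, and the half-grid masking pattern is an assumption about what a training bool_masked_pos would look like:

import torch

batch_size, height, width, patch_size = 2, 448, 448, 16
prompt_pixel_values = torch.randn(batch_size, 3, height, width)
pixel_values = torch.randn(batch_size, 3, height, width)
prompt_masks = torch.randn(batch_size, 3, height, width)
labels = torch.randn(batch_size, 3, height, width)

# Prompt image stacked on top of the input image along the height axis.
stacked_images = torch.cat((prompt_pixel_values, pixel_values), dim=2)  # (2, 3, 896, 448)

# Inference: the prompt mask is duplicated as a placeholder for the half to predict.
inference_input = torch.cat((prompt_masks, prompt_masks), dim=2)  # (2, 3, 896, 448)

# Training target used by the loss later in this diff: prompt mask stacked on the label.
ground_truth = torch.cat((prompt_masks, labels), dim=2)  # (2, 3, 896, 448)

# A plausible training bool_masked_pos: hide every patch in the bottom half
# (the label half) so the model has to reconstruct it.
num_patches = (2 * height // patch_size) * (width // patch_size)  # 56 * 28 = 1568
bool_masked_pos = torch.zeros(num_patches, dtype=torch.bool)
bool_masked_pos[num_patches // 2 :] = True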
@@ -840,7 +855,9 @@ def unpatchify(tensor: torch.Tensor, patch_height: int, patch_width: int) -> torch.Tensor:
     batch_size = tensor.shape[0]
     patch_size = int((tensor.shape[-1] / 3) ** 0.5)
     if patch_height * patch_width != tensor.shape[1]:
-        raise ValueError(f"Number of patches {tensor.shape[1]} does not match patch height and width.")
+        raise ValueError(
+            f"Number of patches {tensor.shape[1]} does not match patch height ({patch_height}) and width ({patch_width})."
+        )
 
     tensor = tensor.reshape(shape=(batch_size, patch_height, patch_width, patch_size, patch_size, 3))
     tensor = tensor.permute(0, 5, 1, 3, 2, 4)
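A quick shape round-trip for unpatchify: importing the module-level helper directly is an assumption for illustration, and the grid numbers follow the 896x448 canvas used above:

import torch
from transformers.models.seggpt.modeling_seggpt import unpatchify

# 56 x 28 = 1568 patches, each carrying 16 * 16 * 3 = 768 values.
batch_size, patch_height, patch_width = 2, 56, 28
patches = torch.randn(batch_size, patch_height * patch_width, 16**2 * 3)

images = unpatchify(patches, patch_height, patch_width)
print(images.shape)  # torch.Size([2, 3, 896, 448])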
@@ -857,20 +874,16 @@ def __init__(self, config):

     def forward(
         self,
-        pixel_values: torch.FloatTensor,
-        prompt_pixel_values: torch.FloatTensor,
+        prompt_masks: torch.FloatTensor,
         pred_masks: torch.FloatTensor,
         labels: torch.FloatTensor,
         bool_masked_pos: torch.BoolTensor,
     ):
         """Computes the smooth L1 loss between the predicted masks and the ground truth masks.
 
         Args:
-            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, 2*height, width)`):
-                Concatenated pixel values from prompt and input images.
-            prompt_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, 2*height, width)`):
-                Concatenated pixel values from mask prompt.
+            prompt_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+                Pixel values from mask prompt.
             pred_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, 2*height, width)`):
                 Predicted masks.
@@ -884,12 +897,12 @@ def forward(
         Returns:
             `torch.FloatTensor`: The mean smooth L1 loss between the predicted masks and the ground truth masks.
         """
+        ground_truth = torch.cat((prompt_masks, labels), dim=2)
+
         mask = bool_masked_pos[:, :, None].repeat(1, 1, self.patch_size**2 * 3)
-        mask = unpatchify(mask, pixel_values.shape[1] // self.patch_size, pixel_values.shape[2] // self.patch_size)
-        # Changing the dummy mask in prompt_pixel_values to the labels values
-        prompt_pixel_values = prompt_pixel_values.clone()
-        prompt_pixel_values[:, :, prompt_pixel_values.shape[2] // 2 :, :] = labels
-        loss = F.smooth_l1_loss(pred_masks, prompt_pixel_values, reduction="none", beta=self.beta)
+        mask = unpatchify(mask, ground_truth.shape[2] // self.patch_size, ground_truth.shape[3] // self.patch_size)
+
+        loss = F.smooth_l1_loss(pred_masks, ground_truth, reduction="none", beta=self.beta)
         loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
 
         return loss
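A self-contained re-enactment of the new loss path on dummy tensors; beta=0.01 is assumed to match SegGptConfig's default, and the half-grid bool_masked_pos is the same assumption as above:

import torch
import torch.nn.functional as F
from transformers.models.seggpt.modeling_seggpt import unpatchify

batch_size, patch_size, beta = 2, 16, 0.01
prompt_masks = torch.randn(batch_size, 3, 448, 448)
labels = torch.randn(batch_size, 3, 448, 448)
pred_masks = torch.randn(batch_size, 3, 896, 448)

# Ground truth is the prompt mask stacked on the label, as in the diff.
ground_truth = torch.cat((prompt_masks, labels), dim=2)  # (2, 3, 896, 448)

# Mask only the label half so the loss averages over reconstructed patches.
num_patches = (896 // patch_size) * (448 // patch_size)  # 1568
bool_masked_pos = torch.zeros(batch_size, num_patches, dtype=torch.bool)
bool_masked_pos[:, num_patches // 2 :] = True

mask = bool_masked_pos[:, :, None].repeat(1, 1, patch_size**2 * 3).float()
mask = unpatchify(mask, ground_truth.shape[2] // patch_size, ground_truth.shape[3] // patch_size)

loss = F.smooth_l1_loss(pred_masks, ground_truth, reduction="none", beta=beta)
loss = (loss * mask).sum() / mask.sum()  # mean loss on masked patches only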
@@ -988,7 +1001,7 @@ def forward(
         loss = None
         if labels is not None:
             loss_fn = SegGptLoss(self.config)
-            loss = loss_fn(pixel_values, prompt_pixel_values, pred_masks, labels, bool_masked_pos)
+            loss = loss_fn(prompt_masks, pred_masks, labels, bool_masked_pos)
 
         if not return_dict:
             output = (pred_masks,)
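Putting the pieces together, a hedged end-to-end training step; the checkpoint name, the attribute path used to read num_patches, and the random tensors are illustrative assumptions:

import torch
from transformers import SegGptForImageSegmentation

model = SegGptForImageSegmentation.from_pretrained("BAAI/seggpt-vit-large")
model.train()

pixel_values = torch.randn(1, 3, 448, 448)
prompt_pixel_values = torch.randn(1, 3, 448, 448)
prompt_masks = torch.randn(1, 3, 448, 448)
labels = torch.randn(1, 3, 448, 448)

# Mask the label half of the patch grid; the warning above recommends
# providing an explicit bool_masked_pos when training.
num_patches = model.model.embeddings.patch_embeddings.num_patches  # assumed attribute path
bool_masked_pos = torch.zeros(1, num_patches, dtype=torch.bool)
bool_masked_pos[:, num_patches // 2 :] = True

outputs = model(
    pixel_values=pixel_values,
    prompt_pixel_values=prompt_pixel_values,
    prompt_masks=prompt_masks,
    labels=labels,
    bool_masked_pos=bool_masked_pos,
)
outputs.loss.backward()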
