Skip to content

Commit

Permalink
local tests pass, but need a test for when there are both empty and n…
Browse files Browse the repository at this point in the history
…on-empty mixed
  • Loading branch information
bw4sz committed Dec 21, 2024
1 parent e9ddebc commit 8230454
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions src/deepforest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,10 +694,23 @@ def validation_step(self, batch, batch_idx):
for index, result in enumerate(preds):
# Skip empty predictions
if result["boxes"].shape[0] == 0:
continue
boxes = visualize.format_geometry(result)
boxes["image_path"] = path[index]
self.predictions.append(boxes)
self.predictions.append(
pd.DataFrame(
{
"image_path": [path[index]],
"xmin": [None],
"ymin": [None],
"xmax": [None],
"ymax": [None],
"label": [None],
"score": [None]
}
)
)
else:
boxes = visualize.format_geometry(result)
boxes["image_path"] = path[index]
self.predictions.append(boxes)

return losses

Expand Down Expand Up @@ -731,13 +744,14 @@ def calculate_empty_frame_accuracy(self, ground_df, predictions_df):
if len(empty_images) == 0:
return None

# Get predictions for empty images
empty_predictions = predictions_df.loc[predictions_df.image_path.isin(empty_images)]
# Get non-empty predictions for empty images
non_empty_predictions = predictions_df.loc[predictions_df.xmin.notnull()]
predictions_for_empty_images = non_empty_predictions.loc[non_empty_predictions.image_path.isin(empty_images)]

# Create prediction tensor - 1 if model predicted objects, 0 if predicted empty
predictions = torch.zeros(len(empty_images))
for index, image in enumerate(empty_images):
if len(empty_predictions.loc[empty_predictions.image_path == image]) > 0:
if len(predictions_for_empty_images.loc[predictions_for_empty_images.image_path == image]) > 0:
predictions[index] = 1

# Ground truth tensor - all zeros since these are empty frames
Expand Down

0 comments on commit 8230454

Please sign in to comment.