Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IDTReeS - Clip boxes outside of image bounds #760

Merged
merged 16 commits into from
Sep 7, 2022
49 changes: 45 additions & 4 deletions torchgeo/datasets/idtrees.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import torch
from rasterio.enums import Resampling
from torch import Tensor
from torchvision.ops import clip_boxes_to_image, remove_small_boxes
from torchvision.utils import draw_bounding_boxes

from .geo import NonGeoDataset
Expand Down Expand Up @@ -211,10 +212,22 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]:
if self.split == "test":
if self.task == "task2":
sample["boxes"] = self._load_boxes(path)
h, w = sample["image"].shape[1:]
sample["boxes"], _ = self._filter_boxes(
image_size=(h, w), boxes=sample["boxes"]
)
else:
sample["boxes"] = self._load_boxes(path)
sample["label"] = self._load_target(path)
boxes = self._load_boxes(path)
labels = self._load_target(path)

h, w = sample["image"].shape[1:]
boxes, labels = self._filter_boxes( # type: ignore[assignment]
image_size=(h, w), boxes=boxes, labels=labels
)
sample["boxes"] = boxes
sample["label"] = labels

# Filter boxes
if self.transforms is not None:
sample = self.transforms(sample)

Expand Down Expand Up @@ -383,14 +396,42 @@ def _load_geometries(self, directory: str) -> Dict[int, Dict[str, Any]]:
features: Dict[int, Dict[str, Any]] = {}
for path in filepaths:
with fiona.open(path) as src:
for i, feature in enumerate(src):
for feature in src:
if self.split == "train":
features[feature["properties"]["id"]] = feature
# Test set task 2 has no id
else:
features[i] = feature
id = len(features) + 1
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
features[id] = feature
return features

def _filter_boxes(
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
self,
image_size: Tuple[int, int],
boxes: Tensor,
min_size: int = 1,
labels: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
"""Clip boxes to image size and filter boxes with sides less than ``min_size``.

Args:
image_size: tuple of (height, width) of image
min_size: filter boxes that have any side less than min_size
boxes: [N, 4] shape tensor of xyxy bounding box coordinates
labels: (Optional) [N,] shape tensor of bounding box labels

Returns:
a tuple of filtered boxes and labels
"""
boxes = clip_boxes_to_image(boxes=boxes, size=image_size)
indices = remove_small_boxes(boxes=boxes, min_size=min_size)

boxes = boxes[indices]
if labels is not None:
labels = labels[indices]

return boxes, labels

def _verify(self) -> None:
"""Verify the integrity of the dataset.

Expand Down