Skip to content

Commit

Permalink
IDTReeS - Clip boxes outside of image bounds (microsoft#760)
Browse files Browse the repository at this point in the history
* reverse idtrees coords

* clip boxes to image bounds

* undo coordinate reversal

* add filter_boxes function

* filter boxes outside of bounds in idtrees

* format

* Revert utils.py

* flake8 fixes

* fix mypy errors

* fix bug overriding some labels

* fix image size

* Remove version added line

* add function overloads

* add comments for clarity

* use id counter for test set
  • Loading branch information
isaaccorley authored Sep 7, 2022
1 parent 3b87caf commit 940788b
Showing 1 changed file with 63 additions and 4 deletions.
67 changes: 63 additions & 4 deletions torchgeo/datasets/idtrees.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import glob
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
from typing import Any, Callable, Dict, List, Optional, Tuple, cast, overload

import fiona
import matplotlib.pyplot as plt
Expand All @@ -14,6 +14,7 @@
import torch
from rasterio.enums import Resampling
from torch import Tensor
from torchvision.ops import clip_boxes_to_image, remove_small_boxes
from torchvision.utils import draw_bounding_boxes

from .geo import NonGeoDataset
Expand Down Expand Up @@ -211,10 +212,22 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]:
if self.split == "test":
if self.task == "task2":
sample["boxes"] = self._load_boxes(path)
h, w = sample["image"].shape[1:]
sample["boxes"], _ = self._filter_boxes(
image_size=(h, w), min_size=1, boxes=sample["boxes"], labels=None
)
else:
sample["boxes"] = self._load_boxes(path)
sample["label"] = self._load_target(path)

h, w = sample["image"].shape[1:]
sample["boxes"], sample["label"] = self._filter_boxes(
image_size=(h, w),
min_size=1,
boxes=sample["boxes"],
labels=sample["label"],
)

if self.transforms is not None:
sample = self.transforms(sample)

Expand Down Expand Up @@ -271,11 +284,15 @@ def _load_boxes(self, path: str) -> Tensor:
geometries = cast(Dict[int, Dict[str, Any]], self.geometries)

# Find object ids and geometries
# The train set geometry->image mapping is contained
# in the train/Field/itc_rsFile.csv file
if self.split == "train":
indices = self.labels["rsFile"] == base_path
ids = self.labels[indices]["id"].tolist()
geoms = [geometries[i]["geometry"]["coordinates"][0][:4] for i in ids]
# Test set - Task 2 has no mapping csv. Mapping is inside of geometry
# The test set has no mapping csv. The mapping is inside of the geometry
# properties i.e. geom["property"]["plotID"] contains the RGB image filename
# Return all geometries with the matching RGB image filename of the sample
else:
ids = [
k
Expand Down Expand Up @@ -380,17 +397,59 @@ def _load_geometries(self, directory: str) -> Dict[int, Dict[str, Any]]:
"""
filepaths = glob.glob(os.path.join(directory, "ITC", "*.shp"))

i = 0
features: Dict[int, Dict[str, Any]] = {}
for path in filepaths:
with fiona.open(path) as src:
for i, feature in enumerate(src):
for feature in src:
# The train set has a unique id for each geometry in the properties
if self.split == "train":
features[feature["properties"]["id"]] = feature
# Test set task 2 has no id
# The test set has no unique id so create a dummy id
else:
features[i] = feature
i += 1
return features

@overload
def _filter_boxes(
self, image_size: Tuple[int, int], min_size: int, boxes: Tensor, labels: Tensor
) -> Tuple[Tensor, Tensor]:
...

@overload
def _filter_boxes(
self, image_size: Tuple[int, int], min_size: int, boxes: Tensor, labels: None
) -> Tuple[Tensor, None]:
...

def _filter_boxes(
self,
image_size: Tuple[int, int],
min_size: int,
boxes: Tensor,
labels: Optional[Tensor],
) -> Tuple[Tensor, Optional[Tensor]]:
"""Clip boxes to image size and filter boxes with sides less than ``min_size``.
Args:
image_size: tuple of (height, width) of image
min_size: filter boxes that have any side less than min_size
boxes: [N, 4] shape tensor of xyxy bounding box coordinates
labels: (Optional) [N,] shape tensor of bounding box labels
Returns:
a tuple of filtered boxes and labels
"""
boxes = clip_boxes_to_image(boxes=boxes, size=image_size)
indices = remove_small_boxes(boxes=boxes, min_size=min_size)

boxes = boxes[indices]
if labels is not None:
labels = labels[indices]

return boxes, labels

def _verify(self) -> None:
"""Verify the integrity of the dataset.
Expand Down

0 comments on commit 940788b

Please sign in to comment.