Skip to content

Commit

Permalink
add function overloads
Browse files Browse the repository at this point in the history
  • Loading branch information
isaaccorley committed Sep 6, 2022
1 parent 05f8f57 commit 4481114
Showing 1 changed file with 23 additions and 11 deletions.
34 changes: 23 additions & 11 deletions torchgeo/datasets/idtrees.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import glob
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
from typing import Any, Callable, Dict, List, Optional, Tuple, cast, overload

import fiona
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -214,20 +214,20 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]:
sample["boxes"] = self._load_boxes(path)
h, w = sample["image"].shape[1:]
sample["boxes"], _ = self._filter_boxes(
image_size=(h, w), boxes=sample["boxes"]
image_size=(h, w), min_size=1, boxes=sample["boxes"], labels=None
)
else:
boxes = self._load_boxes(path)
labels = self._load_target(path)
sample["boxes"] = self._load_boxes(path)
sample["label"] = self._load_target(path)

h, w = sample["image"].shape[1:]
boxes, labels = self._filter_boxes( # type: ignore[assignment]
image_size=(h, w), boxes=boxes, labels=labels
sample["boxes"], sample["label"] = self._filter_boxes(
image_size=(h, w),
min_size=1,
boxes=sample["boxes"],
labels=sample["label"],
)
sample["boxes"] = boxes
sample["label"] = labels

# Filter boxes
if self.transforms is not None:
sample = self.transforms(sample)

Expand Down Expand Up @@ -405,12 +405,24 @@ def _load_geometries(self, directory: str) -> Dict[int, Dict[str, Any]]:
features[id] = feature
return features

@overload
def _filter_boxes(
self, image_size: Tuple[int, int], min_size: int, boxes: Tensor, labels: Tensor
) -> Tuple[Tensor, Tensor]:
...

@overload
def _filter_boxes(
self, image_size: Tuple[int, int], min_size: int, boxes: Tensor, labels: None
) -> Tuple[Tensor, None]:
...

def _filter_boxes(
self,
image_size: Tuple[int, int],
min_size: int,
boxes: Tensor,
min_size: int = 1,
labels: Optional[Tensor] = None,
labels: Optional[Tensor],
) -> Tuple[Tensor, Optional[Tensor]]:
"""Clip boxes to image size and filter boxes with sides less than ``min_size``.
Expand Down

0 comments on commit 4481114

Please sign in to comment.