Skip to content

Commit

Permalink
Fix mypy
Browse files Browse the repository at this point in the history
Signed-off-by: Kim, Vinnam <[email protected]>
  • Loading branch information
vinnamkim committed Apr 7, 2023
1 parent 1aea9bb commit 774bf99
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

# pylint: disable=invalid-name, too-many-locals, no-member

from typing import Any, Dict, List
from typing import Any, Dict, List, Union

import numpy as np
from mmcls.core import average_performance, mAP
Expand Down Expand Up @@ -108,7 +108,7 @@ def __getitem__(self, index: int):
return data_info
return self.pipeline(data_info)

def _get_label_id(self, gt_label: np.ndarray) -> ID:
def _get_label_id(self, gt_label: np.ndarray) -> Union[ID, List[ID]]:
return self.idx_to_label_id.get(gt_label.item(), ID())

def get_gt_labels(self):
Expand Down Expand Up @@ -293,7 +293,7 @@ def evaluate(

return eval_results

def _get_label_id(self, gt_label: np.ndarray) -> List[ID]:
def _get_label_id(self, gt_label: np.ndarray) -> Union[ID, List[ID]]:
return [self.idx_to_label_id.get(idx, ID()) for idx, v in enumerate(gt_label) if v == 1]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class MultiClassClsLossDynamicsTracker(LossDynamicsTracker):
def __init__(self) -> None:
super().__init__()

def init_with_otx_dataset(self, otx_dataset: DatasetEntity) -> None:
def init_with_otx_dataset(self, otx_dataset: DatasetEntity[DatasetItemEntityWithID]) -> None:
"""DatasetEntity should be injected to the tracker for the initialization."""
otx_labels = otx_dataset.get_labels()
label_categories = dm.LabelCategories.from_iterable([label_entity.name for label_entity in otx_labels])
Expand Down
21 changes: 12 additions & 9 deletions otx/api/entities/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import itertools
import logging
from enum import Enum
from typing import Iterator, List, Optional, Union, cast, overload
from typing import Generic, Iterator, List, Optional, TypeVar, Union, cast, overload

from bson.objectid import ObjectId

Expand Down Expand Up @@ -82,7 +82,10 @@ def __next__(self) -> DatasetItemEntity:
return item


class DatasetEntity:
TDatasetItemEntity = TypeVar("TDatasetItemEntity", bound=DatasetItemEntity)


class DatasetEntity(Generic[TDatasetItemEntity]):
"""A dataset consists of a list of DatasetItemEntities and a purpose.
## With dataset items
Expand Down Expand Up @@ -146,7 +149,7 @@ class DatasetEntity:

def __init__(
self,
items: Optional[List[DatasetItemEntity]] = None,
items: Optional[List[TDatasetItemEntity]] = None,
purpose: DatasetPurpose = DatasetPurpose.INFERENCE,
):
self._items = [] if items is None else items
Expand Down Expand Up @@ -270,7 +273,7 @@ def __getitem__(self, key: Union[slice, int]) -> Union["DatasetItemEntity", List
"""
return self._fetch(key)

def __iter__(self) -> Iterator[DatasetItemEntity]:
def __iter__(self) -> Iterator[TDatasetItemEntity]:
"""Return an iterator for the DatasetEntity.
This iterator is able to iterate over the DatasetEntity lazily.
Expand Down Expand Up @@ -308,7 +311,7 @@ def with_empty_annotations(
Returns:
DatasetEntity: a new dataset containing the same items, with empty annotation objects.
"""
new_dataset = DatasetEntity(purpose=self.purpose)
new_dataset = DatasetEntity[DatasetItemEntity](purpose=self.purpose)
for dataset_item in self:
if isinstance(dataset_item, DatasetItemEntity):
empty_annotation = AnnotationSceneEntity(annotations=[], kind=annotation_kind)
Expand Down Expand Up @@ -351,21 +354,21 @@ def get_subset(self, subset: Subset) -> "DatasetEntity":
)
return dataset

def remove(self, item: DatasetItemEntity) -> None:
def remove(self, item: TDatasetItemEntity) -> None:
"""Remove an item from the items.
This function calls remove_at_indices function.
Args:
item (DatasetItemEntity): the item to be deleted.
item (TDatasetItemEntity): the item to be deleted.
Raises:
ValueError: if the input item is not in the dataset
"""
index = self._items.index(item)
self.remove_at_indices([index])

def append(self, item: DatasetItemEntity) -> None:
def append(self, item: TDatasetItemEntity) -> None:
"""Append a DatasetItemEntity to the dataset.
Example:
Expand All @@ -381,7 +384,7 @@ def append(self, item: DatasetItemEntity) -> None:
>>> dataset.append(dataset_item)
Args:
item (DatasetItemEntity): item to append
item (TDatasetItemEntity): item to append
"""

if item.media is None:
Expand Down

0 comments on commit 774bf99

Please sign in to comment.