[Typing][B-61] Add type annotations for `python/paddle/text/datasets/imdb.py` (PaddlePaddle#66037)


---------

Co-authored-by: Nyakku Shigure <[email protected]>
2 people authored and co63oc committed Jul 21, 2024
1 parent 17c0a04 commit c9f70de
Showing 1 changed file with 31 additions and 9 deletions.
40 changes: 31 additions & 9 deletions python/paddle/text/datasets/imdb.py
@@ -11,17 +11,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations

 import collections
 import re
 import string
 import tarfile
+from typing import TYPE_CHECKING, Literal

 import numpy as np

 from paddle.dataset.common import _check_exists_and_download
 from paddle.io import Dataset

+if TYPE_CHECKING:
+    from re import Pattern
+
+    import numpy.typing as npt
+
+_ImdbDataSetMode = Literal["train", "test"]
 __all__ = []

 URL = 'https://dataset.bj.bcebos.com/imdb%2FaclImdb_v1.tar.gz'
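
The additions in this hunk follow the standard annotation-only import pattern: `Pattern` and `numpy.typing` are imported under `if TYPE_CHECKING:` so they are resolved only by static checkers, and `from __future__ import annotations` turns every annotation into a lazily evaluated string, letting the guarded names appear in signatures without existing at runtime. A minimal standalone sketch of the pattern (the `count_matches` function is illustrative, not part of this file):

from __future__ import annotations

import re
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Resolved only by static type checkers (mypy, pyright); skipped at runtime.
    from re import Pattern


def count_matches(pattern: Pattern[str], lines: list[str]) -> int:
    # `Pattern[str]` is stored as a string annotation, so the guarded
    # import above is never executed when this module runs.
    return sum(1 for line in lines if pattern.search(line))


print(count_matches(re.compile(r"/pos/"), ["train/pos/0.txt", "train/neg/1.txt"]))  # prints 1
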
@@ -33,12 +41,12 @@ class Imdb(Dataset):
     Implementation of `IMDB <https://www.imdb.com/interfaces/>`_ dataset.

     Args:
-        data_file(str): path to data tar file, can be set None if
-            :attr:`download` is True. Default None
+        data_file(str|None): path to data tar file, can be set None if
+            :attr:`download` is True. Default None.
         mode(str): 'train' 'test' mode. Default 'train'.
         cutoff(int): cutoff number for building word dictionary. Default 150.
         download(bool): whether to download dataset automatically if
-            :attr:`data_file` is not set. Default True
+            :attr:`data_file` is not set. Default True.

     Returns:
         Dataset: instance of IMDB dataset
@@ -82,7 +90,19 @@ class Imdb(Dataset):
     """

-    def __init__(self, data_file=None, mode='train', cutoff=150, download=True):
+    data_file: str | None
+    mode: _ImdbDataSetMode
+    word_idx: dict[str, int]
+    docs: list
+    labels: list
+
+    def __init__(
+        self,
+        data_file: str | None = None,
+        mode: _ImdbDataSetMode = 'train',
+        cutoff: int = 150,
+        download: bool = True,
+    ) -> None:
         assert mode.lower() in [
             'train',
             'test',
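
The class-level annotations introduced in this hunk (`data_file`, `mode`, `word_idx`, `docs`, `labels`) declare each instance attribute once, so a checker can verify both the assignments made in `__init__` and every later read against a single declaration. A small sketch of the same idea (the `Box` class is illustrative, not from the diff):

class Box:
    # Declared once at class level; a type checker validates every
    # subsequent assignment and read against these declarations.
    label: str
    items: list[int]

    def __init__(self, label: str = 'empty') -> None:
        self.label = label  # OK: str matches the declaration
        self.items = []     # OK: matches list[int]
        # self.items = 'oops'  # a checker would flag this as an error


print(Box('full').label)
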
@@ -104,7 +124,7 @@ def __init__(self, data_file=None, mode='train', cutoff=150, download=True):
         # read dataset into memory
         self._load_anno()

-    def _build_work_dict(self, cutoff):
+    def _build_work_dict(self, cutoff: int) -> dict[str, int]:
         word_freq = collections.defaultdict(int)
         pattern = re.compile(r"aclImdb/((train)|(test))/((pos)|(neg))/.*\.txt$")
         for doc in self._tokenize(pattern):
@@ -120,7 +140,7 @@ def _build_work_dict(self, cutoff):
         word_idx['<unk>'] = len(words)
         return word_idx

-    def _tokenize(self, pattern):
+    def _tokenize(self, pattern: Pattern[str]) -> list[list[str]]:
         data = []
         with tarfile.open(self.data_file) as tarf:
             tf = tarf.next()
@@ -139,7 +159,7 @@ def _tokenize(self, pattern):

         return data

-    def _load_anno(self):
+    def _load_anno(self) -> None:
         pos_pattern = re.compile(fr"aclImdb/{self.mode}/pos/.*\.txt$")
         neg_pattern = re.compile(fr"aclImdb/{self.mode}/neg/.*\.txt$")

@@ -154,8 +174,10 @@ def _load_anno(self):
             self.docs.append([self.word_idx.get(w, UNK) for w in doc])
             self.labels.append(1)

-    def __getitem__(self, idx):
+    def __getitem__(
+        self, idx: int
+    ) -> tuple[npt.NDArray[np.int_], npt.NDArray[np.int_]]:
         return (np.array(self.docs[idx]), np.array([self.labels[idx]]))

-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.docs)

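With these annotations in place, a hedged usage sketch of the typed dataset (assumes a PaddlePaddle install; the import path mirrors this file's location, and construction downloads the aclImdb tarball when `data_file` is None):

from paddle.text.datasets import Imdb

imdb = Imdb(mode='test')  # mode is checked against Literal['train', 'test']

# __getitem__ now advertises a pair of int arrays, so doc/label get
# useful inferred types; a typo such as mode='tset' is a checker error.
doc, label = imdb[0]
print(len(imdb), doc.dtype, label.shape)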