Skip to content

Commit

Permalink
fix mypy error
Browse files Browse the repository at this point in the history
  • Loading branch information
qinzzz committed Feb 19, 2022
1 parent ca63c5b commit f6ca640
Show file tree
Hide file tree
Showing 7 changed files with 9 additions and 8 deletions.
4 changes: 2 additions & 2 deletions forte/data/converter/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def convert(
masks_tensor_list: List[torch.Tensor] = []
for batch_masks_dim_i in masks_list:
masks_tensor_list.append(
self._to_tensor_type(batch_masks_dim_i, np.bool)
self._to_tensor_type(batch_masks_dim_i, np.bool_)
)

return data_tensor, masks_tensor_list
Expand All @@ -293,7 +293,7 @@ def convert(
masks_np_list: List[np.ndarray] = []
for batch_masks_dim_i in masks_list:
masks_np_list.append(
self._to_numpy_type(batch_masks_dim_i, np.bool)
self._to_numpy_type(batch_masks_dim_i, np.bool_)
)
return data_np, masks_np_list

Expand Down
2 changes: 1 addition & 1 deletion forte/data/data_pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def get_span_text(self, begin: int, end: int) -> str:
"""
return self._text[begin:end]

def get_span_audio(self, begin: int, end: int) -> str:
def get_span_audio(self, begin: int, end: int) -> np.ndarray:
r"""Get the audio in the data pack contained in the span.
`begin` and `end` represent the starting and ending indices of the span
in audio payload respectively. Each index corresponds to one sample in
Expand Down
2 changes: 1 addition & 1 deletion forte/data/ontology/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ class FNdArray:
"""

def __init__(
self, dtype: Optional[str] = None, shape: Optional[Iterable[int]] = None
self, dtype: Optional[str] = None, shape: Optional[List[int]] = None
):
super().__init__()
self._dtype: Optional[np.dtype] = (
Expand Down
3 changes: 2 additions & 1 deletion forte/data/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import Dict, List, Tuple, Type, Union

import torch
import numpy as np

from forte.data.ontology.core import Entry
from forte.data.span import Span
Expand All @@ -25,4 +26,4 @@

DataRequest = Dict[Type[Entry], Union[Dict, List]]

MatrixLike = Union[torch.TensorType, List]
MatrixLike = Union[torch.TensorType, np.ndarray, List]
2 changes: 1 addition & 1 deletion forte/processors/base/batch_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def _process_batch(
pred = self.predict(batched_data)
self.pack_all(packs, contexts, pred)

def predict(self, data_batch: Dict) -> Dict[str, List[Any]]:
def predict(self, data_batch: Dict) -> Dict[str, Dict[str, List[Any]]]:
r"""The function that task processors should implement. Make
predictions for the input ``data_batch``.
Expand Down
2 changes: 1 addition & 1 deletion forte/processors/ir/bert_based_query_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def _build_query(self, text: str) -> np.ndarray:

def _process_query(
self, input_pack: MultiPack
) -> Tuple[DataPack, Dict[str, Any]]:
) -> Tuple[DataPack, np.ndarray[Any, Any]]:
query_pack: DataPack = input_pack.get_pack(self.config.query_pack_name)
context = [query_pack.text]

Expand Down
2 changes: 1 addition & 1 deletion forte/processors/nlp/ner_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def load_model(path):
@torch.no_grad()
def predict(
self, data_batch: Dict[str, Dict[str, List[str]]]
) -> Dict[str, Dict[str, List[np.array]]]:
) -> Dict[str, Dict[str, List[np.ndarray]]]:
tokens = data_batch["Token"]

instances = []
Expand Down

0 comments on commit f6ca640

Please sign in to comment.