Commit 31c7788

mypy
robertwb committed Jun 9, 2022
1 parent a4d253e commit 31c7788
Showing 2 changed files with 8 additions and 8 deletions.
sdks/python/apache_beam/ml/inference/base.py: 10 changes (5 additions, 5 deletions)
@@ -35,8 +35,8 @@
 from typing import Any
 from typing import Generic
 from typing import Iterable
-from typing import List
 from typing import Mapping
+from typing import Sequence
 from typing import Tuple
 from typing import TypeVar

@@ -68,13 +68,13 @@ def _to_microseconds(time_ns: int) -> int:
 
 class InferenceRunner(Generic[ExampleT, PredictionT, ModelT]):
   """Implements running inferences for a framework."""
-  def run_inference(self, batch: List[ExampleT], model: ModelT,
+  def run_inference(self, batch: Sequence[ExampleT], model: ModelT,
                     **kwargs) -> Iterable[PredictionT]:
     """Runs inferences on a batch of examples and
     returns an Iterable of Predictions."""
     raise NotImplementedError(type(self))
 
-  def get_num_bytes(self, batch: List[ExampleT]) -> int:
+  def get_num_bytes(self, batch: Sequence[ExampleT]) -> int:
     """Returns the number of bytes of data for a batch."""
     return len(pickle.dumps(batch))
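
As an aside for readers of this diff: below is a minimal, self-contained sketch of what an implementation of this two-method interface can look like against the new Sequence-typed signature. The DoublingRunner class and its integer "model" are invented for illustration only; they are not part of this commit or of apache_beam.

import pickle
from typing import Iterable, Sequence

class DoublingRunner:
  """Toy stand-in for an InferenceRunner subclass; the "model" is just an int."""
  def run_inference(self, batch: Sequence[int], model: int,
                    **kwargs) -> Iterable[int]:
    # A read-only Sequence accepts a list as well as the tuple that the
    # keyed wrapper below produces via zip(*batch).
    return [model * example for example in batch]

  def get_num_bytes(self, batch: Sequence[int]) -> int:
    # Same default as the base class: size of the pickled batch.
    return len(pickle.dumps(batch))

print(list(DoublingRunner().run_inference((1, 2, 3), model=10)))  # [10, 20, 30]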

@@ -131,13 +131,13 @@ def __init__(self, unkeyed: InferenceRunner[ExampleT, PredictionT, ModelT]):
     self._unkeyed = unkeyed
 
   def run_inference(
-      self, batch: List[Tuple[KeyT, ExampleT]], model: ModelT,
+      self, batch: Sequence[Tuple[KeyT, ExampleT]], model: ModelT,
       **kwargs) -> Iterable[Tuple[KeyT, PredictionT]]:
     keys, unkeyed_batch = zip(*batch)
     return zip(
         keys, self._unkeyed.run_inference(unkeyed_batch, model, **kwargs))
 
-  def get_num_bytes(self, batch: List[Tuple[KeyT, ExampleT]]) -> int:
+  def get_num_bytes(self, batch: Sequence[Tuple[KeyT, ExampleT]]) -> int:
     keys, unkeyed_batch = zip(*batch)
     return len(pickle.dumps(keys)) + self._unkeyed.get_num_bytes(unkeyed_batch)
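
A likely reason this change matters to mypy (the commit message's namesake): zip(*batch) in the wrapper above yields tuples, and a tuple satisfies Sequence[...] but not List[...]. A tiny standalone illustration, with function names made up for this example:

from typing import List, Sequence

def wants_list(batch: List[int]) -> int:
  return len(batch)

def wants_seq(batch: Sequence[int]) -> int:
  return len(batch)

examples = (1, 2, 3)   # a tuple, like the unkeyed_batch that zip(*batch) yields
wants_seq(examples)    # accepted by mypy: a tuple is a Sequence
wants_list(examples)   # mypy error: expected "List[int]", got "Tuple[int, int, int]"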

sdks/python/apache_beam/ml/inference/pytorch_inference.py: 6 changes (3 additions, 3 deletions)
@@ -22,7 +22,7 @@
 from typing import Callable
 from typing import Dict
 from typing import Iterable
-from typing import List
+from typing import Sequence
 from typing import Union
 
 import torch
@@ -56,7 +56,7 @@ def _convert_to_device(self, examples: torch.Tensor) -> torch.Tensor:
 
   def run_inference(
       self,
-      batch: List[Union[torch.Tensor, Dict[str, torch.Tensor]]],
+      batch: Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]],
       model: torch.nn.Module,
       **kwargs) -> Iterable[PredictionResult]:
     """
@@ -88,7 +88,7 @@ def run_inference(
     predictions = model(batched_tensors, **prediction_params)
     return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
 
-  def get_num_bytes(self, batch: List[torch.Tensor]) -> int:
+  def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
     """Returns the number of bytes of data for a batch of Tensors."""
     # If elements in `batch` are provided as a dictionaries from key to Tensors
     if isinstance(batch[0], dict):
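
The remainder of get_num_bytes is collapsed in this view, so the following is only a hedged sketch of how a batch of Tensors (or of dicts of Tensors) might be sized with standard torch calls; approx_batch_bytes is a name invented here, not the hidden implementation.

from typing import Dict, Sequence, Union

import torch

def approx_batch_bytes(
    batch: Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> int:
  """Rough byte count for a batch of Tensors or of dicts of Tensors."""
  total = 0
  for example in batch:
    tensors = example.values() if isinstance(example, dict) else [example]
    for t in tensors:
      total += t.element_size() * t.nelement()  # bytes per element * element count
  return total

print(approx_batch_bytes([torch.zeros(10), {"x": torch.zeros(5)}]))  # 60 bytes for float32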
