Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make keying of examples explicit. #21777

Merged
merged 9 commits into from
Jun 10, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 102 additions & 20 deletions sdks/python/apache_beam/ml/inference/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@
from typing import Any
from typing import Generic
from typing import Iterable
from typing import List
from typing import Mapping
from typing import Sequence
from typing import Tuple
from typing import TypeVar

import apache_beam as beam
Expand All @@ -54,6 +55,7 @@
ModelT = TypeVar('ModelT')
ExampleT = TypeVar('ExampleT')
PredictionT = TypeVar('PredictionT')
KeyT = TypeVar('KeyT')


def _to_milliseconds(time_ns: int) -> int:
Expand All @@ -70,13 +72,13 @@ def load_model(self) -> ModelT:
"""Loads and initializes a model for processing."""
raise NotImplementedError(type(self))

def run_inference(self, batch: List[ExampleT], model: ModelT,
def run_inference(self, batch: Sequence[ExampleT], model: ModelT,
**kwargs) -> Iterable[PredictionT]:
"""Runs inferences on a batch of examples and
returns an Iterable of Predictions."""
raise NotImplementedError(type(self))

def get_num_bytes(self, batch: List[ExampleT]) -> int:
def get_num_bytes(self, batch: Sequence[ExampleT]) -> int:
"""Returns the number of bytes of data for a batch."""
return len(pickle.dumps(batch))

Expand All @@ -93,6 +95,100 @@ def batch_elements_kwargs(self) -> Mapping[str, Any]:
return {}


class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if there's any way to make this more readable or simple. All these nested lists are making my eyes a little buggy.

Can we perhaps use constants here?
BASIC_MODEL_HANDLER = ModelHandler[Tuple[KeyT, ExampleT]
KEYED_PREDICTION = Tuple[KeyT, PredictionT]

Or are there any other ways to make some of this templating go away?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I played around with this a bit, don't see a way to really make things much simpler. It is possible delete the generic, but then it becomes harder to reason about the order of the nested arguments.

ModelHandler[Tuple[KeyT, ExampleT],
Tuple[KeyT, PredictionT],
ModelT]):
"""A ModelHandler that takes keyed examples and returns keyed predictions.

For example, if the original model was used with RunInference to take a
PCollection[E] to a PCollection[P], this would take a
PCollection[Tuple[K, E]] to a PCollection[Tuple[K, P]], allowing one to
associate the outputs with the inputs based on the key.
"""
def __init__(self, unkeyed: ModelHandler[ExampleT, PredictionT, ModelT]):
self._unkeyed = unkeyed

def load_model(self) -> ModelT:
return self._unkeyed.load_model()

def run_inference(
self, batch: Sequence[Tuple[KeyT, ExampleT]], model: ModelT,
**kwargs) -> Iterable[Tuple[KeyT, PredictionT]]:
keys, unkeyed_batch = zip(*batch)
return zip(
keys, self._unkeyed.run_inference(unkeyed_batch, model, **kwargs))

def get_num_bytes(self, batch: Sequence[Tuple[KeyT, ExampleT]]) -> int:
keys, unkeyed_batch = zip(*batch)
return len(pickle.dumps(keys)) + self._unkeyed.get_num_bytes(unkeyed_batch)

def get_metrics_namespace(self) -> str:
return self._unkeyed.get_metrics_namespace()

def get_resource_hints(self):
return self._unkeyed.get_resource_hints()

def batch_elements_kwargs(self):
return self._unkeyed.batch_elements_kwargs()
return {}


class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
ModelHandler[Tuple[KeyT, ExampleT],
Tuple[KeyT, PredictionT],
ModelT]):
"""A ModelHandler that takes possibly keyed examples and returns possibly
keyed predictions.

For example, if the original model was used with RunInference to take a
PCollection[E] to a PCollection[P], this would take either PCollection[E] to a
PCollection[P] or PCollection[Tuple[K, E]] to a PCollection[Tuple[K, P]],
depending on the whether the elements happen to be tuples, allowing one to
associate the outputs with the inputs based on the key.

Note that this cannot be used if E happens to be a tuple type.
"""
def __init__(self, unkeyed: ModelHandler[ExampleT, PredictionT, ModelT]):
self._unkeyed = unkeyed

def load_model(self) -> ModelT:
return self._unkeyed.load_model()

def run_inference(
self, batch: Sequence[Tuple[KeyT, ExampleT]], model: ModelT,
**kwargs) -> Iterable[Tuple[KeyT, PredictionT]]:
if isinstance(batch[0], tuple):
is_keyed = True
keys, unkeyed_batch = zip(*batch)
else:
is_keyed = False
unkeyed_batch = batch
unkeyed_results = self._unkeyed.run_inference(
unkeyed_batch, model, **kwargs)
if is_keyed:
return zip(keys, unkeyed_results)
else:
return unkeyed_results

def get_num_bytes(self, batch: Sequence[Tuple[KeyT, ExampleT]]) -> int:
if isinstance(batch[0], tuple):
keys, unkeyed_batch = zip(*batch)
return len(
pickle.dumps(keys)) + self._unkeyed.get_num_bytes(unkeyed_batch)
else:
return self._unkeyed.get_num_bytes(batch)

def get_metrics_namespace(self) -> str:
return self._unkeyed.get_metrics_namespace()

def get_resource_hints(self):
return self._unkeyed.get_resource_hints()

def batch_elements_kwargs(self):
return self._unkeyed.batch_elements_kwargs()


class RunInference(beam.PTransform[beam.PCollection[ExampleT],
beam.PCollection[PredictionT]]):
"""An extensible transform for running inferences.
Expand Down Expand Up @@ -205,32 +301,18 @@ def setup(self):
self._model = self._load_model()

def process(self, batch, **kwargs):
# Process supports both keyed data, and example only data.
# First keys and samples are separated (if there are keys)
has_keys = isinstance(batch[0], tuple)
if has_keys:
examples = [example for _, example in batch]
keys = [key for key, _ in batch]
else:
examples = batch
keys = None

start_time = _to_microseconds(self._clock.time_ns())
result_generator = self._model_handler.run_inference(
examples, self._model, **kwargs)
batch, self._model, **kwargs)
predictions = list(result_generator)

end_time = _to_microseconds(self._clock.time_ns())
inference_latency = end_time - start_time
num_bytes = self._model_handler.get_num_bytes(examples)
num_bytes = self._model_handler.get_num_bytes(batch)
num_elements = len(batch)
self._metrics_collector.update(num_elements, num_bytes, inference_latency)

# Keys are recombined with predictions in the RunInference PTransform.
if has_keys:
yield from zip(keys, predictions)
else:
yield from predictions
return predictions

def finish_bundle(self):
# TODO(BEAM-13970): Figure out why there is a cache.
Expand Down
20 changes: 19 additions & 1 deletion sdks/python/apache_beam/ml/inference/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,27 @@ def test_run_inference_impl_with_keyed_examples(self):
keyed_examples = [(i, example) for i, example in enumerate(examples)]
expected = [(i, example + 1) for i, example in enumerate(examples)]
pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
actual = pcoll | base.RunInference(FakeModelHandler())
actual = pcoll | base.RunInference(
base.KeyedModelHandler(FakeModelHandler()))
assert_that(actual, equal_to(expected), label='assert:inferences')

def test_run_inference_impl_with_maybe_keyed_examples(self):
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10]
keyed_examples = [(i, example) for i, example in enumerate(examples)]
expected = [example + 1 for example in examples]
keyed_expected = [(i, example + 1) for i, example in enumerate(examples)]
model_handler = base.MaybeKeyedModelHandler(FakeModelHandler())

pcoll = pipeline | 'Unkeyed' >> beam.Create(examples)
actual = pcoll | 'RunUnkeyed' >> base.RunInference(model_handler)
assert_that(actual, equal_to(expected), label='CheckUnkeyed')

keyed_pcoll = pipeline | 'Keyed' >> beam.Create(keyed_examples)
keyed_actual = keyed_pcoll | 'RunKeyed' >> base.RunInference(
model_handler)
assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')

def test_run_inference_impl_kwargs(self):
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10]
Expand Down
6 changes: 3 additions & 3 deletions sdks/python/apache_beam/ml/inference/pytorch_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import List
from typing import Sequence
from typing import Union

import torch
Expand Down Expand Up @@ -87,7 +87,7 @@ def _convert_to_device(self, examples: torch.Tensor) -> torch.Tensor:

def run_inference(
self,
batch: List[Union[torch.Tensor, Dict[str, torch.Tensor]]],
batch: Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]],
model: torch.nn.Module,
**kwargs) -> Iterable[PredictionResult]:
"""
Expand Down Expand Up @@ -119,7 +119,7 @@ def run_inference(
predictions = model(batched_tensors, **prediction_params)
return [PredictionResult(x, y) for x, y in zip(batch, predictions)]

def get_num_bytes(self, batch: List[torch.Tensor]) -> int:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do the same Sequence change for sklearn_inference.py lines 78, 89, 97, 116, 121

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
"""Returns the number of bytes of data for a batch of Tensors."""
# If elements in `batch` are provided as a dictionaries from key to Tensors
if isinstance(batch[0], dict):
Expand Down