Refactor API code to base.py in RunInference (#21801)
* Refactor code from api to base

* Delete api.py

* Modify imports

* Add TODO pointing to the mypy GitHub issue

* Refactor code to reflect changes of #21777

* Refactor example with KeyedModelHandler

* Remove explicit type hints from RunInference class

* Fixup: lint

* Remove TODO referencing the GitHub issue for the mypy error

* Add mypy GitHub issue as TODO
AnandInguva authored Jun 13, 2022
1 parent 87a7dcc commit 4d04f50
Showing 8 changed files with 57 additions and 95 deletions.
@@ -28,8 +28,8 @@
 import torch
 from apache_beam.io.filesystems import FileSystems
 from apache_beam.ml.inference.base import KeyedModelHandler
-from apache_beam.ml.inference.api import PredictionResult
-from apache_beam.ml.inference.api import RunInference
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
 from apache_beam.ml.inference.pytorch_inference import PytorchModelHandler
 from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.options.pipeline_options import SetupOptions
@@ -135,9 +135,7 @@ def run(argv=None, model_class=None, model_params=None, save_main_session=True):
           lambda file_name, data: (file_name, preprocess_image(data))))
   predictions = (
       filename_value_pair
-      |
-      'PyTorchRunInference' >> RunInference(model_handler).with_output_types(
-          Tuple[str, PredictionResult])
+      | 'PyTorchRunInference' >> RunInference(model_handler)
       | 'ProcessOutput' >> beam.ParDo(PostProcessor()))

   if known_args.output:
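
Since the handler here is wrapped in KeyedModelHandler, each (file_name, image) element comes back as a (file_name, PredictionResult) pair, and RunInference now infers that output type from the handler itself, which is why the explicit with_output_types hint could be dropped. Below is a minimal sketch of that keyed pattern; TinyModel, the state-dict path, and the input tensor are hypothetical placeholders, not the example's real configuration:

import apache_beam as beam
import torch
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandler


class TinyModel(torch.nn.Module):
  """Hypothetical stand-in for the example's real image model."""
  def __init__(self):
    super().__init__()
    self.linear = torch.nn.Linear(3, 1)

  def forward(self, x):
    return self.linear(x)


# The state-dict path is a placeholder; the real example threads
# model_class and model_params through run()'s arguments instead.
model_handler = KeyedModelHandler(
    PytorchModelHandler(
        state_dict_path='/tmp/tiny_model.pt',
        model_class=TinyModel,
        model_params={}))

with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | beam.Create([('image_0.jpg', torch.tensor([1.0, 2.0, 3.0]))])
      | 'PyTorchRunInference' >> RunInference(model_handler)
      # Each output is ('image_0.jpg', PredictionResult(example=..., inference=...))
      | beam.Map(print))

The same shape applies to any framework handler: wrap it in KeyedModelHandler and keys pass through untouched.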
1 change: 1 addition & 0 deletions sdks/python/apache_beam/ml/inference/__init__.py
@@ -14,3 +14,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from apache_beam.ml.inference.base import RunInference
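
With this re-export, RunInference is importable from the package root as well as from base; after the change both names resolve to the same class:

from apache_beam.ml.inference import RunInference
from apache_beam.ml.inference.base import RunInference as BaseRunInference

# The package-level name is just a re-export of the base module's class.
assert RunInference is BaseRunInference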
62 changes: 0 additions & 62 deletions sdks/python/apache_beam/ml/inference/api.py

This file was deleted.

24 changes: 24 additions & 0 deletions sdks/python/apache_beam/ml/inference/base.py
@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+# TODO: https://github.com/apache/beam/issues/21822
+# mypy: ignore-errors

 """An extensible run inference transform.
@@ -32,6 +34,7 @@
 import pickle
 import sys
 import time
+from dataclasses import dataclass
 from typing import Any
 from typing import Generic
 from typing import Iterable
@@ -56,9 +59,17 @@
 ModelT = TypeVar('ModelT')
 ExampleT = TypeVar('ExampleT')
 PredictionT = TypeVar('PredictionT')
+_INPUT_TYPE = TypeVar('_INPUT_TYPE')
+_OUTPUT_TYPE = TypeVar('_OUTPUT_TYPE')
 KeyT = TypeVar('KeyT')


+@dataclass
+class PredictionResult:
+  example: _INPUT_TYPE
+  inference: _OUTPUT_TYPE
+
+
 def _to_milliseconds(time_ns: int) -> int:
   return int(time_ns / _NANOSECOND_TO_MILLISECOND)

@@ -206,6 +217,19 @@ class RunInference(beam.PTransform[beam.PCollection[ExampleT],
   Args:
     model_handler: An implementation of ModelHandler.
     clock: A clock implementing get_current_time_in_microseconds.
+  A transform that takes a PCollection of examples (or features) to be used on
+  an ML model. It will then output inferences (or predictions) for those
+  examples in a PCollection of PredictionResults, containing the input examples
+  and output inferences.
+  If examples are paired with keys, it will output a tuple
+  (key, PredictionResult) for each (key, example) input.
+  Models for supported frameworks can be loaded via a URI. Supported services
+  can also be used.
+  TODO(BEAM-14046): Add and link to help documentation
   """
  def __init__(
      self,
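
PredictionResult relocated here from the deleted api.py along with its _INPUT_TYPE and _OUTPUT_TYPE type variables; it is a plain dataclass pairing each input example with the inference produced for it. A small usage sketch with illustrative values:

from apache_beam.ml.inference.base import PredictionResult

# A dataclass pairing one input example with the inference made for it.
result = PredictionResult(example=[1.0, 2.0, 3.0], inference=0.75)
print(result.example)    # [1.0, 2.0, 3.0]
print(result.inference)  # 0.75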
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -27,8 +27,8 @@

 import torch
 from apache_beam.io.filesystems import FileSystems
-from apache_beam.ml.inference.api import PredictionResult
 from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import PredictionResult


 class PytorchModelHandler(ModelHandler[torch.Tensor,
@@ -35,7 +35,7 @@
 # pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
 try:
   import torch
-  from apache_beam.ml.inference.api import PredictionResult
+  from apache_beam.ml.inference.base import PredictionResult
   from apache_beam.ml.inference.base import RunInference
   from apache_beam.ml.inference.pytorch_inference import PytorchModelHandler
 except ImportError:
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/ml/inference/sklearn_inference.py
@@ -28,8 +28,8 @@
 from sklearn.base import BaseEstimator

 from apache_beam.io.filesystems import FileSystems
-from apache_beam.ml.inference.api import PredictionResult
 from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import PredictionResult

 try:
   import joblib
51 changes: 26 additions & 25 deletions sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
@@ -37,8 +37,9 @@
 from sklearn.preprocessing import StandardScaler

 import apache_beam as beam
-from apache_beam.ml.inference import api
-from apache_beam.ml.inference import base
+from apache_beam.ml.inference.base import KeyedModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
 from apache_beam.ml.inference.sklearn_inference import ModelFileType
 from apache_beam.ml.inference.sklearn_inference import SklearnModelHandler
 from apache_beam.testing.test_pipeline import TestPipeline
@@ -134,9 +135,9 @@ def test_predict_output(self):
         numpy.array([1, 2, 3]), numpy.array([4, 5, 6]), numpy.array([7, 8, 9])
     ]
     expected_predictions = [
-        api.PredictionResult(numpy.array([1, 2, 3]), 6),
-        api.PredictionResult(numpy.array([4, 5, 6]), 15),
-        api.PredictionResult(numpy.array([7, 8, 9]), 24)
+        PredictionResult(numpy.array([1, 2, 3]), 6),
+        PredictionResult(numpy.array([4, 5, 6]), 15),
+        PredictionResult(numpy.array([7, 8, 9]), 24)
     ]
     inferences = inference_runner.run_inference(batched_examples, fake_model)
     for actual, expected in zip(inferences, expected_predictions):
@@ -181,11 +182,11 @@ def test_pipeline_pickled(self):

       pcoll = pipeline | 'start' >> beam.Create(examples)
       #TODO(BEAM-14305) Test against the public API.
-      actual = pcoll | base.RunInference(
+      actual = pcoll | RunInference(
           SklearnModelHandler(model_uri=temp_file_name))
       expected = [
-          api.PredictionResult(numpy.array([0, 0]), 0),
-          api.PredictionResult(numpy.array([1, 1]), 1)
+          PredictionResult(numpy.array([0, 0]), 0),
+          PredictionResult(numpy.array([1, 1]), 1)
       ]
       assert_that(
           actual, equal_to(expected, equals_fn=_compare_prediction_result))
@@ -201,12 +202,12 @@ def test_pipeline_joblib(self):
       pcoll = pipeline | 'start' >> beam.Create(examples)
       #TODO(BEAM-14305) Test against the public API.

-      actual = pcoll | base.RunInference(
+      actual = pcoll | RunInference(
           SklearnModelHandler(
               model_uri=temp_file_name, model_file_type=ModelFileType.JOBLIB))
       expected = [
-          api.PredictionResult(numpy.array([0, 0]), 0),
-          api.PredictionResult(numpy.array([1, 1]), 1)
+          PredictionResult(numpy.array([0, 0]), 0),
+          PredictionResult(numpy.array([1, 1]), 1)
       ]
       assert_that(
           actual, equal_to(expected, equals_fn=_compare_prediction_result))
@@ -217,7 +218,7 @@ def test_bad_file_raises(self):
       examples = [numpy.array([0, 0])]
       pcoll = pipeline | 'start' >> beam.Create(examples)
       # TODO(BEAM-14305) Test against the public API.
-      _ = pcoll | base.RunInference(
+      _ = pcoll | RunInference(
           SklearnModelHandler(model_uri='/var/bad_file_name'))
       pipeline.run()
@@ -239,15 +240,15 @@ def test_pipeline_pandas(self):
       dataframe = pandas_dataframe()
       splits = [dataframe.loc[[i]] for i in dataframe.index]
       pcoll = pipeline | 'start' >> beam.Create(splits)
-      actual = pcoll | api.RunInference(
+      actual = pcoll | RunInference(
          SklearnModelHandler(model_uri=temp_file_name))

       expected = [
-          api.PredictionResult(splits[0], 5),
-          api.PredictionResult(splits[1], 8),
-          api.PredictionResult(splits[2], 1),
-          api.PredictionResult(splits[3], 1),
-          api.PredictionResult(splits[4], 2),
+          PredictionResult(splits[0], 5),
+          PredictionResult(splits[1], 8),
+          PredictionResult(splits[2], 1),
+          PredictionResult(splits[3], 1),
+          PredictionResult(splits[4], 2),
       ]
       assert_that(
           actual, equal_to(expected, equals_fn=_compare_dataframe_predictions))
@@ -264,14 +265,14 @@ def test_pipeline_pandas_with_keys(self):
       keyed_rows = [(key, value) for key, value in zip(keys, splits)]

       pcoll = pipeline | 'start' >> beam.Create(keyed_rows)
-      actual = pcoll | api.RunInference(
-          base.KeyedModelHandler(SklearnModelHandler(model_uri=temp_file_name)))
+      actual = pcoll | RunInference(
+          KeyedModelHandler(SklearnModelHandler(model_uri=temp_file_name)))
       expected = [
-          ('0', api.PredictionResult(splits[0], 5)),
-          ('1', api.PredictionResult(splits[1], 8)),
-          ('2', api.PredictionResult(splits[2], 1)),
-          ('3', api.PredictionResult(splits[3], 1)),
-          ('4', api.PredictionResult(splits[4], 2)),
+          ('0', PredictionResult(splits[0], 5)),
+          ('1', PredictionResult(splits[1], 8)),
+          ('2', PredictionResult(splits[2], 1)),
+          ('3', PredictionResult(splits[3], 1)),
+          ('4', PredictionResult(splits[4], 2)),
       ]
       assert_that(
           actual, equal_to(expected, equals_fn=_compare_dataframe_predictions))
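
These sklearn tests compare pipeline output through equal_to(..., equals_fn=...) because PredictionResult holds numpy arrays (or DataFrames), whose == is elementwise rather than a single bool. The helpers' bodies are not part of this diff; the following is a plausible sketch of _compare_prediction_result under that assumption, not the test module's actual code:

import numpy

from apache_beam.ml.inference.base import PredictionResult


def _compare_prediction_result(a, b):
  # Plausible reimplementation; the real helper's body is not shown here.
  # numpy equality must be collapsed to one bool explicitly.
  return (numpy.array_equal(a.example, b.example) and
          a.inference == b.inference)


print(_compare_prediction_result(
    PredictionResult(numpy.array([1, 2]), 3),
    PredictionResult(numpy.array([1, 2]), 3)))  # True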
