
Commit

refactor code from api to base
AnandInguva committed Jun 10, 2022
1 parent ef7cd0c commit 6829459
Showing 7 changed files with 52 additions and 24 deletions.
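In practice, this commit adds PredictionResult to apache_beam.ml.inference.base and switches the framework handlers and tests to import it (and RunInference) from base instead of api. Caller imports change roughly as follows; this is a sketch based on the import changes below, not code taken from the commit itself:

# Before this commit:
from apache_beam.ml.inference.api import PredictionResult
from apache_beam.ml.inference.api import RunInference

# After this commit:
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference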
1 change: 1 addition & 0 deletions sdks/python/apache_beam/ml/inference/__init__.py
@@ -14,3 +14,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from apache_beam.ml.inference.base import RunInference
28 changes: 28 additions & 0 deletions sdks/python/apache_beam/ml/inference/base.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
# mypy: ignore-errors

"""An extensible run inference transform.
@@ -32,12 +33,15 @@
import pickle
import sys
import time
from dataclasses import dataclass
from typing import Any
from typing import Generic
from typing import Iterable
from typing import List
from typing import Mapping
from typing import Tuple
from typing import TypeVar
from typing import Union

import apache_beam as beam
from apache_beam.utils import shared
@@ -54,6 +58,15 @@
ModelT = TypeVar('ModelT')
ExampleT = TypeVar('ExampleT')
PredictionT = TypeVar('PredictionT')
_K = TypeVar('_K')
_INPUT_TYPE = TypeVar('_INPUT_TYPE')
_OUTPUT_TYPE = TypeVar('_OUTPUT_TYPE')


@dataclass
class PredictionResult:
example: _INPUT_TYPE
inference: _OUTPUT_TYPE


def _to_milliseconds(time_ns: int) -> int:
@@ -93,12 +106,27 @@ def batch_elements_kwargs(self) -> Mapping[str, Any]:
return {}


@beam.typehints.with_input_types(Union[_INPUT_TYPE, Tuple[_K, _INPUT_TYPE]])
@beam.typehints.with_output_types(Union[PredictionResult, Tuple[_K, PredictionResult]]) # pylint: disable=line-too-long
class RunInference(beam.PTransform[beam.PCollection[ExampleT],
beam.PCollection[PredictionT]]):
"""An extensible transform for running inferences.
Args:
model_handler: An implementation of ModelHandler.
clock: A clock implementing get_current_time_in_microseconds.
A transform that takes a PCollection of examples (or features) to be used on
an ML model. It will then output inferences (or predictions) for those
examples in a PCollection of PredictionResults, containing the input examples
and output inferences.
If examples are paired with keys, it will output a tuple
(key, PredictionResult) for each (key, example) input.
Models for supported frameworks can be loaded via a URI. Supported services
can also be used.
TODO(BEAM-14046): Add and link to help documentation
"""
def __init__(
self,
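Together with the new PredictionResult dataclass, the updated docstring describes both keyed and unkeyed usage. A minimal usage sketch (not part of this commit; the model URI is hypothetical and assumes a pickled sklearn model is available there):

import apache_beam as beam
import numpy
from apache_beam.ml.inference import base
from apache_beam.ml.inference.sklearn_inference import SklearnModelHandler

with beam.Pipeline() as pipeline:
  # Unkeyed examples: each element comes back as a base.PredictionResult
  # pairing the input example with the model's inference.
  examples = pipeline | 'CreateExamples' >> beam.Create(
      [numpy.array([0, 0]), numpy.array([1, 1])])
  predictions = examples | 'Inference' >> base.RunInference(
      SklearnModelHandler(model_uri='gs://my-bucket/model.pkl'))

  # Keyed examples: (key, example) pairs come back as
  # (key, base.PredictionResult) tuples.
  keyed_examples = pipeline | 'CreateKeyed' >> beam.Create(
      [('a', numpy.array([0, 0])), ('b', numpy.array([1, 1]))])
  keyed_predictions = keyed_examples | 'KeyedInference' >> base.RunInference(
      SklearnModelHandler(model_uri='gs://my-bucket/model.pkl'))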
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -27,8 +27,8 @@

import torch
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference.api import PredictionResult
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult


class PytorchModelHandler(ModelHandler[torch.Tensor,
@@ -66,7 +66,7 @@ class PyTorchInference(unittest.TestCase):
@pytest.mark.uses_pytorch
@pytest.mark.it_postcommit
def test_torch_run_inference_imagenet_mobilenetv2(self):
test_pipeline = TestPipeline(is_integration_test=True)
test_pipeline = TestPipeline(is_integration_test=False)
# text files containing absolute path to the imagenet validation data on GCS
file_of_image_names = 'gs://apache-beam-ml/testing/inputs/it_mobilenetv2_imagenet_validation_inputs.txt' # disable: line-too-long
output_file_dir = 'gs://apache-beam-ml/testing/predictions'
@@ -35,7 +35,7 @@
# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
try:
import torch
from apache_beam.ml.inference.api import PredictionResult
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandler
except ImportError:
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/ml/inference/sklearn_inference.py
@@ -28,7 +28,7 @@
from sklearn.base import BaseEstimator

from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference.api import PredictionResult
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import ModelHandler

try:
39 changes: 19 additions & 20 deletions sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
@@ -37,7 +37,6 @@
from sklearn.preprocessing import StandardScaler

import apache_beam as beam
from apache_beam.ml.inference import api
from apache_beam.ml.inference import base
from apache_beam.ml.inference.sklearn_inference import ModelFileType
from apache_beam.ml.inference.sklearn_inference import SklearnModelHandler
@@ -134,9 +133,9 @@ def test_predict_output(self):
numpy.array([1, 2, 3]), numpy.array([4, 5, 6]), numpy.array([7, 8, 9])
]
expected_predictions = [
api.PredictionResult(numpy.array([1, 2, 3]), 6),
api.PredictionResult(numpy.array([4, 5, 6]), 15),
api.PredictionResult(numpy.array([7, 8, 9]), 24)
base.PredictionResult(numpy.array([1, 2, 3]), 6),
base.PredictionResult(numpy.array([4, 5, 6]), 15),
base.PredictionResult(numpy.array([7, 8, 9]), 24)
]
inferences = inference_runner.run_inference(batched_examples, fake_model)
for actual, expected in zip(inferences, expected_predictions):
@@ -184,8 +183,8 @@ def test_pipeline_pickled(self):
actual = pcoll | base.RunInference(
SklearnModelHandler(model_uri=temp_file_name))
expected = [
api.PredictionResult(numpy.array([0, 0]), 0),
api.PredictionResult(numpy.array([1, 1]), 1)
base.PredictionResult(numpy.array([0, 0]), 0),
base.PredictionResult(numpy.array([1, 1]), 1)
]
assert_that(
actual, equal_to(expected, equals_fn=_compare_prediction_result))
@@ -205,8 +204,8 @@ def test_pipeline_joblib(self):
SklearnModelHandler(
model_uri=temp_file_name, model_file_type=ModelFileType.JOBLIB))
expected = [
api.PredictionResult(numpy.array([0, 0]), 0),
api.PredictionResult(numpy.array([1, 1]), 1)
base.PredictionResult(numpy.array([0, 0]), 0),
base.PredictionResult(numpy.array([1, 1]), 1)
]
assert_that(
actual, equal_to(expected, equals_fn=_compare_prediction_result))
@@ -239,15 +238,15 @@ def test_pipeline_pandas(self):
dataframe = pandas_dataframe()
splits = [dataframe.loc[[i]] for i in dataframe.index]
pcoll = pipeline | 'start' >> beam.Create(splits)
actual = pcoll | api.RunInference(
actual = pcoll | base.RunInference(
SklearnModelHandler(model_uri=temp_file_name))

expected = [
api.PredictionResult(splits[0], 5),
api.PredictionResult(splits[1], 8),
api.PredictionResult(splits[2], 1),
api.PredictionResult(splits[3], 1),
api.PredictionResult(splits[4], 2),
base.PredictionResult(splits[0], 5),
base.PredictionResult(splits[1], 8),
base.PredictionResult(splits[2], 1),
base.PredictionResult(splits[3], 1),
base.PredictionResult(splits[4], 2),
]
assert_that(
actual, equal_to(expected, equals_fn=_compare_dataframe_predictions))
@@ -264,14 +263,14 @@ def test_pipeline_pandas_with_keys(self):
keyed_rows = [(key, value) for key, value in zip(keys, splits)]

pcoll = pipeline | 'start' >> beam.Create(keyed_rows)
actual = pcoll | api.RunInference(
actual = pcoll | base.RunInference(
SklearnModelHandler(model_uri=temp_file_name))
expected = [
('0', api.PredictionResult(splits[0], 5)),
('1', api.PredictionResult(splits[1], 8)),
('2', api.PredictionResult(splits[2], 1)),
('3', api.PredictionResult(splits[3], 1)),
('4', api.PredictionResult(splits[4], 2)),
('0', base.PredictionResult(splits[0], 5)),
('1', base.PredictionResult(splits[1], 8)),
('2', base.PredictionResult(splits[2], 1)),
('3', base.PredictionResult(splits[3], 1)),
('4', base.PredictionResult(splits[4], 2)),
]
assert_that(
actual, equal_to(expected, equals_fn=_compare_dataframe_predictions))
