
Commit

modify imports
AnandInguva committed Jun 10, 2022
1 parent 8235a5b commit 5ecbf64
Showing 1 changed file with 24 additions and 23 deletions.
sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
@@ -37,7 +37,8 @@
 from sklearn.preprocessing import StandardScaler
 
 import apache_beam as beam
-from apache_beam.ml.inference import base
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
 from apache_beam.ml.inference.sklearn_inference import ModelFileType
 from apache_beam.ml.inference.sklearn_inference import SklearnModelHandler
 from apache_beam.testing.test_pipeline import TestPipeline
@@ -133,9 +134,9 @@ def test_predict_output(self):
         numpy.array([1, 2, 3]), numpy.array([4, 5, 6]), numpy.array([7, 8, 9])
     ]
     expected_predictions = [
-        base.PredictionResult(numpy.array([1, 2, 3]), 6),
-        base.PredictionResult(numpy.array([4, 5, 6]), 15),
-        base.PredictionResult(numpy.array([7, 8, 9]), 24)
+        PredictionResult(numpy.array([1, 2, 3]), 6),
+        PredictionResult(numpy.array([4, 5, 6]), 15),
+        PredictionResult(numpy.array([7, 8, 9]), 24)
     ]
     inferences = inference_runner.run_inference(batched_examples, fake_model)
     for actual, expected in zip(inferences, expected_predictions):
Expand Down Expand Up @@ -180,11 +181,11 @@ def test_pipeline_pickled(self):

pcoll = pipeline | 'start' >> beam.Create(examples)
#TODO(BEAM-14305) Test against the public API.
actual = pcoll | base.RunInference(
actual = pcoll | RunInference(
SklearnModelHandler(model_uri=temp_file_name))
expected = [
base.PredictionResult(numpy.array([0, 0]), 0),
base.PredictionResult(numpy.array([1, 1]), 1)
PredictionResult(numpy.array([0, 0]), 0),
PredictionResult(numpy.array([1, 1]), 1)
]
assert_that(
actual, equal_to(expected, equals_fn=_compare_prediction_result))
@@ -200,12 +201,12 @@ def test_pipeline_joblib(self):
       pcoll = pipeline | 'start' >> beam.Create(examples)
       #TODO(BEAM-14305) Test against the public API.
 
-      actual = pcoll | base.RunInference(
+      actual = pcoll | RunInference(
           SklearnModelHandler(
               model_uri=temp_file_name, model_file_type=ModelFileType.JOBLIB))
       expected = [
-          base.PredictionResult(numpy.array([0, 0]), 0),
-          base.PredictionResult(numpy.array([1, 1]), 1)
+          PredictionResult(numpy.array([0, 0]), 0),
+          PredictionResult(numpy.array([1, 1]), 1)
       ]
       assert_that(
           actual, equal_to(expected, equals_fn=_compare_prediction_result))
@@ -216,7 +217,7 @@ def test_bad_file_raises(self):
         examples = [numpy.array([0, 0])]
         pcoll = pipeline | 'start' >> beam.Create(examples)
         # TODO(BEAM-14305) Test against the public API.
-        _ = pcoll | base.RunInference(
+        _ = pcoll | RunInference(
             SklearnModelHandler(model_uri='/var/bad_file_name'))
       pipeline.run()
 
@@ -238,15 +239,15 @@ def test_pipeline_pandas(self):
       dataframe = pandas_dataframe()
       splits = [dataframe.loc[[i]] for i in dataframe.index]
       pcoll = pipeline | 'start' >> beam.Create(splits)
-      actual = pcoll | base.RunInference(
+      actual = pcoll | RunInference(
          SklearnModelHandler(model_uri=temp_file_name))
 
       expected = [
-          base.PredictionResult(splits[0], 5),
-          base.PredictionResult(splits[1], 8),
-          base.PredictionResult(splits[2], 1),
-          base.PredictionResult(splits[3], 1),
-          base.PredictionResult(splits[4], 2),
+          PredictionResult(splits[0], 5),
+          PredictionResult(splits[1], 8),
+          PredictionResult(splits[2], 1),
+          PredictionResult(splits[3], 1),
+          PredictionResult(splits[4], 2),
       ]
       assert_that(
           actual, equal_to(expected, equals_fn=_compare_dataframe_predictions))
@@ -263,14 +264,14 @@ def test_pipeline_pandas_with_keys(self):
       keyed_rows = [(key, value) for key, value in zip(keys, splits)]
 
       pcoll = pipeline | 'start' >> beam.Create(keyed_rows)
-      actual = pcoll | base.RunInference(
+      actual = pcoll | RunInference(
          SklearnModelHandler(model_uri=temp_file_name))
       expected = [
-          ('0', base.PredictionResult(splits[0], 5)),
-          ('1', base.PredictionResult(splits[1], 8)),
-          ('2', base.PredictionResult(splits[2], 1)),
-          ('3', base.PredictionResult(splits[3], 1)),
-          ('4', base.PredictionResult(splits[4], 2)),
+          ('0', PredictionResult(splits[0], 5)),
+          ('1', PredictionResult(splits[1], 8)),
+          ('2', PredictionResult(splits[2], 1)),
+          ('3', PredictionResult(splits[3], 1)),
+          ('4', PredictionResult(splits[4], 2)),
       ]
       assert_that(
           actual, equal_to(expected, equals_fn=_compare_dataframe_predictions))
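For readers unfamiliar with the API under test, here is a minimal runnable sketch of the pattern this file exercises, written with the direct imports the commit adopts. The toy model, the temporary-file handling, and the final Map(print) step are illustrative assumptions, not part of the commit; the RunInference and SklearnModelHandler(model_uri=...) usage mirrors the diff above.

# Minimal sketch (assumed setup, not part of the commit): train and pickle a
# tiny scikit-learn model, then run it through RunInference using the
# directly imported names adopted by this change.
import pickle
import tempfile

import numpy
from sklearn import linear_model

import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.sklearn_inference import SklearnModelHandler

# Fit a trivial linear model: y ~= x1 + x2.
x = numpy.array([[0, 0], [1, 1], [2, 2]])
y = numpy.array([0, 2, 4])
model = linear_model.LinearRegression()
model.fit(x, y)

# Pickle the model to a temporary file, mirroring test_pipeline_pickled.
with tempfile.NamedTemporaryFile(suffix='.pkl', delete=False) as f:
  pickle.dump(model, f)
  model_path = f.name

# Each output element is a PredictionResult pairing an input example with
# the model's prediction for it.
with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | beam.Create([numpy.array([0, 0]), numpy.array([1, 1])])
      | RunInference(SklearnModelHandler(model_uri=model_path))
      | beam.Map(print))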
