Skip to content

Commit

Permalink
Merge pull request #21805 from AnandInguva/sklearn-tests-bug
Browse files Browse the repository at this point in the history
  • Loading branch information
robertwb authored Jun 10, 2022
2 parents a079e97 + 8c3d322 commit ef7cd0c
Showing 1 changed file with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def test_pipeline_pandas(self):
splits = [dataframe.loc[[i]] for i in dataframe.index]
pcoll = pipeline | 'start' >> beam.Create(splits)
actual = pcoll | api.RunInference(
SklearnModelLoader(model_uri=temp_file_name))
SklearnModelHandler(model_uri=temp_file_name))

expected = [
api.PredictionResult(splits[0], 5),
Expand All @@ -265,7 +265,7 @@ def test_pipeline_pandas_with_keys(self):

pcoll = pipeline | 'start' >> beam.Create(keyed_rows)
actual = pcoll | api.RunInference(
SklearnModelLoader(model_uri=temp_file_name))
SklearnModelHandler(model_uri=temp_file_name))
expected = [
('0', api.PredictionResult(splits[0], 5)),
('1', api.PredictionResult(splits[1], 8)),
Expand All @@ -279,14 +279,14 @@ def test_pipeline_pandas_with_keys(self):
def test_infer_invalid_data_type(self):
with self.assertRaises(ValueError):
unexpected_input_type = [[1, 2, 3, 4], [5, 6, 7, 8]]
inference_runner = SklearnModelLoader(model_uri=unused)
inference_runner = SklearnModelHandler(model_uri='unused')
fake_model = FakeModel()
inference_runner.run_inference(unexpected_input_type, fake_model)

def test_infer_too_many_rows_in_dataframe(self):
with self.assertRaises(ValueError):
data_frame_too_many_rows = pandas_dataframe()
inference_runner = SklearnModelLoader(model_uri=unused)
inference_runner = SklearnModelHandler(model_uri='unused')
fake_model = FakeModel()
inference_runner.run_inference([data_frame_too_many_rows], fake_model)

Expand Down

0 comments on commit ef7cd0c

Please sign in to comment.