From c29ef8a47616b90d62a024574c58613d345bbe28 Mon Sep 17 00:00:00 2001
From: Yang Pan
Date: Mon, 29 Nov 2021 12:37:14 -0800
Subject: [PATCH] feat(component): add code for creating and monitoring bq job

PiperOrigin-RevId: 412956806
---
 .../bigquery_query_job_remote_runner.py       | 171 ++++++++++++
 .../experimental/gcp_launcher/launcher.py     |   6 +-
 .../test_bigquery_query_job_remote_runner.py  | 251 +++++++++++++++++
 .../gcp_launcher/test_launcher.py             | 259 ++++++++++--------
 4 files changed, 573 insertions(+), 114 deletions(-)
 create mode 100644 components/google-cloud/google_cloud_pipeline_components/container/experimental/gcp_launcher/bigquery_query_job_remote_runner.py
 create mode 100644 components/google-cloud/tests/container/experimental/gcp_launcher/test_bigquery_query_job_remote_runner.py

diff --git a/components/google-cloud/google_cloud_pipeline_components/container/experimental/gcp_launcher/bigquery_query_job_remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/experimental/gcp_launcher/bigquery_query_job_remote_runner.py
new file mode 100644
index 000000000000..c4fcfb5bd697
--- /dev/null
+++ b/components/google-cloud/google_cloud_pipeline_components/container/experimental/gcp_launcher/bigquery_query_job_remote_runner.py
@@ -0,0 +1,171 @@
+# Copyright 2021 The Kubeflow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import json
+import re
+import os
+import requests
+import time
+import google.auth
+import google.auth.transport.requests
+
+from .utils import json_util
+from .utils import artifact_util
+from google.cloud import bigquery
+from google.protobuf import json_format
+from google_cloud_pipeline_components.proto.gcp_resources_pb2 import GcpResources
+from os import path
+from typing import Optional
+
+_POLLING_INTERVAL_IN_SECONDS = 20
+_BQ_JOB_NAME_TEMPLATE = r'(https://www.googleapis.com/bigquery/v2/projects/(?P<project>.*)/jobs/(?P<job>.*)\?location=(?P<location>.*))'
+
+
+def check_if_job_exists(gcp_resources) -> Optional[str]:
+  """Checks if the BigQuery job has already been created.
+
+  Returns the job URI if it exists; returns None otherwise.
+  """
+  if path.exists(gcp_resources) and os.stat(gcp_resources).st_size != 0:
+    with open(gcp_resources) as f:
+      serialized_gcp_resources = f.read()
+      job_resources = json_format.Parse(serialized_gcp_resources,
+                                        GcpResources())
+      # Resources should only contain one item.
+      if len(job_resources.resources) != 1:
+        raise ValueError(
+            f'gcp_resources should contain one resource, found {len(job_resources.resources)}'
+        )
+      # Validate the format of the resource uri.
+      job_name_pattern = re.compile(_BQ_JOB_NAME_TEMPLATE)
+      match = job_name_pattern.match(job_resources.resources[0].resource_uri)
+      try:
+        project = match.group('project')
+        job = match.group('job')
+      except AttributeError as err:
+        raise ValueError('Invalid bigquery job uri: {}. Expect: {}.'.format(
+            job_resources.resources[0].resource_uri,
+            'https://www.googleapis.com/bigquery/v2/projects/[projectId]/jobs/[jobId]?location=[location]'
+        ))
+
+    return job_resources.resources[0].resource_uri
+  else:
+    return None
+
+
+def create_job(job_type, project, location, payload, creds,
+               gcp_resources) -> str:
+  """Creates a new BigQuery job."""
+  job_configuration = json.loads(payload, strict=False)
+  # Always use standard SQL instead of legacy SQL.
+  job_configuration['query']['useLegacySql'] = False
+  job_request = {
+      # TODO(IronPan) temporarily remove the empty fields from the spec
+      'configuration': json_util.recursive_remove_empty(job_configuration),
+  }
+  if location is not None:
+    if 'jobReference' not in job_request:
+      job_request['jobReference'] = {}
+    job_request['jobReference']['location'] = location
+
+  creds.refresh(google.auth.transport.requests.Request())
+  headers = {
+      'Content-type': 'application/json',
+      'Authorization': 'Bearer ' + creds.token,
+      'User-Agent': 'google-cloud-pipeline-components'
+  }
+  insert_job_url = f'https://www.googleapis.com/bigquery/v2/projects/{project}/jobs'
+  job = requests.post(
+      url=insert_job_url, data=json.dumps(job_request), headers=headers).json()
+  if 'selfLink' not in job:
+    raise RuntimeError(
+        'BigQuery job failed. Cannot retrieve the job name. Response: {}.'
+        .format(job))
+
+  # Write the BigQuery job URI to the gcp_resources output file.
+  job_uri = job['selfLink']
+  job_resources = GcpResources()
+  job_resource = job_resources.resources.add()
+  job_resource.resource_type = job_type
+  job_resource.resource_uri = job_uri
+  with open(gcp_resources, 'w') as f:
+    f.write(json_format.MessageToJson(job_resources))
+
+  return job_uri
+
+
+def poll_job(job_uri, creds) -> dict:
+  """Polls the BigQuery job until it reaches a final state."""
+  job = {}
+  while ('status' not in job) or ('state' not in job['status']) or (
+      job['status']['state'].lower() != 'done'):
+    time.sleep(_POLLING_INTERVAL_IN_SECONDS)
+    logging.info('The job is running...')
+    if not creds.valid:
+      creds.refresh(google.auth.transport.requests.Request())
+    headers = {
+        'Content-type': 'application/json',
+        'Authorization': 'Bearer ' + creds.token
+    }
+    job = requests.get(job_uri, headers=headers).json()
+    if 'status' in job and 'errorResult' in job['status']:
+      raise RuntimeError('The BigQuery job failed. Error: {}'.format(
+          job['status']))
+
+  logging.info('BigQuery job completed successfully. Job: %s.', job)
+  return job
+
+
+def create_bigquery_job(
+    type,
+    project,
+    location,
+    payload,
+    gcp_resources,
+    executor_input,
+):
+  """Creates and polls a BigQuery job until it reaches a final state.
+
+  This follows the typical launching logic:
+  1. Read if the BigQuery job already exists in gcp_resources.
+     - If it already exists, jump to step 3 and poll the job status. This
+       happens if the launcher container experienced an unexpected
+       termination, such as preemption.
+  2. Deserialize the payload into the job spec and create the BigQuery job.
+  3. Poll the BigQuery job status every
+     job_remote_runner._POLLING_INTERVAL_IN_SECONDS seconds.
+     - If the job succeeded, return successfully.
+     - If the job is pending/running, continue polling the status.
+
+  Also retry on ConnectionError up to
+  job_remote_runner._CONNECTION_ERROR_RETRY_LIMIT times during the poll.
+  """
+  creds, _ = google.auth.default()
+  job_uri = check_if_job_exists(gcp_resources)
+  if job_uri is None:
+    job_uri = create_job(type, project, location, payload, creds,
+                         gcp_resources)
+
+  # Poll the BigQuery job status until it finishes.
+  job = poll_job(job_uri, creds)
+
+  # Write the destination_table output artifact.
+  if 'destinationTable' in job['configuration']['query']:
+    projectId = job['configuration']['query']['destinationTable']['projectId']
+    datasetId = job['configuration']['query']['destinationTable']['datasetId']
+    tableId = job['configuration']['query']['destinationTable']['tableId']
+    artifact_util.update_output_artifact(
+        executor_input, 'destinationTable',
+        f'https://www.googleapis.com/bigquery/v2/projects/{projectId}/datasets/{datasetId}/tables/{tableId}',
+        {})
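
For reference, the `_BQ_JOB_NAME_TEMPLATE` pattern above uses named capture groups, which is how `check_if_job_exists` recovers the job identity from a previously written resource URI. A minimal sketch of the match (the project and job IDs are made up for illustration):

    import re

    _BQ_JOB_NAME_TEMPLATE = r'(https://www.googleapis.com/bigquery/v2/projects/(?P<project>.*)/jobs/(?P<job>.*)\?location=(?P<location>.*))'

    match = re.compile(_BQ_JOB_NAME_TEMPLATE).match(
        'https://www.googleapis.com/bigquery/v2/projects/my-project/jobs/job_123?location=US')
    # Prints: my-project job_123 US
    print(match.group('project'), match.group('job'), match.group('location'))

A URI without the `?location=` suffix produces no match, so `match.group(...)` raises AttributeError and the function surfaces the ValueError exercised by the wrong-format test further down.
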
diff --git a/components/google-cloud/google_cloud_pipeline_components/container/experimental/gcp_launcher/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/experimental/gcp_launcher/launcher.py
index 8d937c72056e..6e40a834d0d7 100644
--- a/components/google-cloud/google_cloud_pipeline_components/container/experimental/gcp_launcher/launcher.py
+++ b/components/google-cloud/google_cloud_pipeline_components/container/experimental/gcp_launcher/launcher.py
@@ -22,6 +22,7 @@
 from . import upload_model_remote_runner
 from . import export_model_remote_runner
 from . import deploy_model_remote_runner
+from . import bigquery_query_job_remote_runner
 from . import wait_gcp_resources
 
 
@@ -80,7 +81,8 @@ def _parse_args(args):
       # executor_input is only needed for components that emit output artifacts.
       required=(parsed_args.type == 'UploadModel' or
                 parsed_args.type == 'CreateEndpoint' or
-                parsed_args.type == 'BatchPredictionJob'),
+                parsed_args.type == 'BatchPredictionJob' or
+                parsed_args.type == 'BigqueryQueryJob'),
       default=argparse.SUPPRESS)
   parser.add_argument(
       "--output_info",
@@ -127,6 +129,8 @@ def main(argv):
     export_model_remote_runner.export_model(**parsed_args)
   if parsed_args['type'] == 'DeployModel':
     deploy_model_remote_runner.deploy_model(**parsed_args)
+  if parsed_args['type'] == 'BigqueryQueryJob':
+    bigquery_query_job_remote_runner.create_bigquery_job(**parsed_args)
   if parsed_args['type'] == 'Wait':
     wait_gcp_resources.wait_gcp_resources(**parsed_args)
 
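
With this wiring, the new job type is reachable through the launcher's existing CLI contract. A sketch of the equivalent in-process invocation, mirroring the launcher tests below (all argument values are illustrative):

    from google_cloud_pipeline_components.container.experimental.gcp_launcher import launcher

    # Dispatches to bigquery_query_job_remote_runner.create_bigquery_job(**parsed_args).
    # --executor_input is required for BigqueryQueryJob since it emits an output artifact;
    # in a real run the KFP backend supplies it (abbreviated to '{}' here).
    launcher.main([
        '--type', 'BigqueryQueryJob',
        '--project', 'my-project',
        '--location', 'US',
        '--payload', '{"query": {"query": "SELECT 1"}}',
        '--gcp_resources', '/tmp/gcp_resources',
        '--executor_input', '{}',
    ])
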
+ """ + creds, _ = google.auth.default() + job_uri = check_if_job_exists(gcp_resources) + if job_uri is None: + job_uri = create_job(type, project, location, payload, creds, gcp_resources) + + # Poll bigquery job status until finished. + job = poll_job(job_uri, creds) + + # write destination_table output artifact + if 'destinationTable' in job['configuration']['query']: + projectId = job['configuration']['query']['destinationTable']['projectId'] + datasetId = job['configuration']['query']['destinationTable']['datasetId'] + tableId = job['configuration']['query']['destinationTable']['tableId'] + artifact_util.update_output_artifact( + executor_input, 'destinationTable', + f'https://www.googleapis.com/bigquery/v2/projects/{projectId}/datasets/{datasetId}/tables/{tableId}', + {}) diff --git a/components/google-cloud/google_cloud_pipeline_components/container/experimental/gcp_launcher/launcher.py b/components/google-cloud/google_cloud_pipeline_components/container/experimental/gcp_launcher/launcher.py index 8d937c72056e..6e40a834d0d7 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/experimental/gcp_launcher/launcher.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/experimental/gcp_launcher/launcher.py @@ -22,6 +22,7 @@ from . import upload_model_remote_runner from . import export_model_remote_runner from . import deploy_model_remote_runner +from . import bigquery_query_job_remote_runner from . import wait_gcp_resources @@ -80,7 +81,8 @@ def _parse_args(args): # executor_input is only needed for components that emit output artifacts. required=(parsed_args.type == 'UploadModel' or parsed_args.type == 'CreateEndpoint' or - parsed_args.type == 'BatchPredictionJob'), + parsed_args.type == 'BatchPredictionJob' or + parsed_args.type == 'BigqueryQueryJob'), default=argparse.SUPPRESS) parser.add_argument( "--output_info", @@ -127,6 +129,8 @@ def main(argv): export_model_remote_runner.export_model(**parsed_args) if parsed_args['type'] == 'DeployModel': deploy_model_remote_runner.deploy_model(**parsed_args) + if parsed_args['type'] == 'BigqueryQueryJob': + bigquery_query_job_remote_runner.create_bigquery_job(**parsed_args) if parsed_args['type'] == 'Wait': wait_gcp_resources.wait_gcp_resources(**parsed_args) diff --git a/components/google-cloud/tests/container/experimental/gcp_launcher/test_bigquery_query_job_remote_runner.py b/components/google-cloud/tests/container/experimental/gcp_launcher/test_bigquery_query_job_remote_runner.py new file mode 100644 index 000000000000..260410dbf72b --- /dev/null +++ b/components/google-cloud/tests/container/experimental/gcp_launcher/test_bigquery_query_job_remote_runner.py @@ -0,0 +1,251 @@ +# Copyright 2021 The Kubeflow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Test BigQuery Query Job Remote Runner module.""" + +import json +from logging import raiseExceptions +import os +import time +import unittest +from unittest import mock +import requests +import google.auth +import google.auth.transport.requests + +from google.protobuf import json_format +from google_cloud_pipeline_components.proto.gcp_resources_pb2 import GcpResources +from google_cloud_pipeline_components.container.experimental.gcp_launcher import bigquery_query_job_remote_runner +from google_cloud_pipeline_components.container.experimental.gcp_launcher import job_remote_runner + + +class BigqueryQueryJobRemoteRunnerUtilsTests(unittest.TestCase): + + def setUp(self): + super(BigqueryQueryJobRemoteRunnerUtilsTests, self).setUp() + self._payload = ( + '{"query": {"query": "CREATE OR REPLACE MODEL ' + 'bqml_tutorial.penguins_model OPTIONS (model_type=\'linear_reg\', ' + 'input_label_cols=[\'body_mass_g\']) AS SELECT * FROM ' + '`bigquery-public-data.ml_datasets.penguins` WHERE body_mass_g IS NOT ' + 'NULL"}}') + self._job_type = 'BigqueryQueryJob' + self._project = 'test_project' + self._location = 'US' + self._job_uri = 'https://www.googleapis.com/bigquery/v2/projects/test_project/jobs/fake_job?location=US' + self._gcp_resources = os.path.join( + os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'), 'gcp_resources') + self._output_file_path = os.path.join( + os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'), 'localpath/foo') + self._executor_input = '{"outputs":{"artifacts":{"destinationTable":{"artifacts":[{"metadata":{},"name":"foobar","type":{"schemaTitle":"google.BQTable"}}]}},"outputFile":"' + self._output_file_path + '"}}' + + def tearDown(self): + if os.path.exists(self._gcp_resources): + os.remove(self._gcp_resources) + + @mock.patch.object(google.auth, 'default', autospec=True) + @mock.patch.object(google.auth.transport.requests, 'Request', autospec=True) + @mock.patch.object(requests, 'post', autospec=True) + @mock.patch.object(requests, 'get', autospec=True) + @mock.patch.object(time, 'sleep', autospec=True) + def test_bigquery_query_job_remote_runner_succeeded(self, mock_time_sleep, + mock_get_requests, + mock_post_requests, _, + mock_auth): + creds = mock.Mock() + creds.token = 'fake_token' + mock_auth.return_value = [creds, 'project'] + mock_created_bq_job = mock.Mock() + mock_created_bq_job.json.return_value = {'selfLink': self._job_uri} + mock_post_requests.return_value = mock_created_bq_job + + mock_polled_bq_job = mock.Mock() + mock_polled_bq_job.json.return_value = { + 'selfLink': self._job_uri, + 'status': { + 'state': 'DONE' + }, + 'configuration': { + 'query': { + 'destinationTable': { + 'projectId': 'test_project', + 'datasetId': 'test_dataset', + 'tableId': 'test_table' + } + } + } + } + mock_get_requests.return_value = mock_polled_bq_job + + bigquery_query_job_remote_runner.create_bigquery_job( + self._job_type, self._project, self._location, self._payload, + self._gcp_resources, self._executor_input) + mock_post_requests.assert_called_once_with( + url=f'https://www.googleapis.com/bigquery/v2/projects/{self._project}/jobs', + data=( + '{"configuration": {"query": {"query": "CREATE OR REPLACE MODEL bqml_tutorial.penguins_model OPTIONS (model_type=\'linear_reg\', input_label_cols=[\'body_mass_g\']) AS SELECT * FROM `bigquery-public-data.ml_datasets.penguins` WHERE body_mass_g IS NOT NULL"}}, "jobReference": {"location": "US"}}' + ), + headers={ + 'Content-type': 'application/json', + 'Authorization': 'Bearer fake_token', + 'User-Agent': 'google-cloud-pipeline-components' + }) + + with 
+
+  @mock.patch.object(google.auth, 'default', autospec=True)
+  @mock.patch.object(google.auth.transport.requests, 'Request', autospec=True)
+  @mock.patch.object(requests, 'post', autospec=True)
+  def test_bigquery_query_job_remote_runner_failed_no_selflink(
+      self, mock_post_requests, _, mock_auth):
+    creds = mock.Mock()
+    creds.token = 'fake_token'
+    mock_auth.return_value = [creds, 'project']
+    mock_created_bq_job = mock.Mock()
+    mock_created_bq_job.json.return_value = {}
+    mock_post_requests.return_value = mock_created_bq_job
+
+    with self.assertRaises(RuntimeError):
+      bigquery_query_job_remote_runner.create_bigquery_job(
+          self._job_type, self._project, self._location, self._payload,
+          self._gcp_resources, self._executor_input)
+
+  @mock.patch.object(google.auth, 'default', autospec=True)
+  @mock.patch.object(google.auth.transport.requests, 'Request', autospec=True)
+  @mock.patch.object(requests, 'get', autospec=True)
+  @mock.patch.object(time, 'sleep', autospec=True)
+  def test_bigquery_query_job_remote_runner_poll_existing_job_failed(
+      self, mock_time_sleep, mock_get_requests, _, mock_auth):
+    # Mimic the case that self._gcp_resources already stores the job uri.
+    with open(self._gcp_resources, 'w') as f:
+      f.write(
+          '{"resources": [{"resourceType": "BigqueryQueryJob", "resourceUri": "https://www.googleapis.com/bigquery/v2/projects/test_project/jobs/fake_job?location=US"}]}'
+      )
+
+    creds = mock.Mock()
+    creds.token = 'fake_token'
+    mock_auth.return_value = [creds, 'project']
+
+    mock_polled_bq_job = mock.Mock()
+    mock_polled_bq_job.json.return_value = {
+        'selfLink': self._job_uri,
+        'status': {
+            'state': 'DONE',
+            'errorResult': {
+                'foo': 'bar'
+            }
+        },
+        'configuration': {
+            'query': {
+                'destinationTable': {
+                    'projectId': 'test_project',
+                    'datasetId': 'test_dataset',
+                    'tableId': 'test_table'
+                }
+            }
+        }
+    }
+    mock_get_requests.return_value = mock_polled_bq_job
+
+    with self.assertRaises(RuntimeError):
+      bigquery_query_job_remote_runner.create_bigquery_job(
+          self._job_type, self._project, self._location, self._payload,
+          self._gcp_resources, self._executor_input)
+
+    self.assertEqual(mock_time_sleep.call_count, 1)
+    self.assertEqual(mock_get_requests.call_count, 1)
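
Both the runner and these tests treat the gcp_resources file as a JSON-serialized GcpResources proto. A small round-trip sketch of the format the tests write by hand (the URI is the same fake one the tests use):

    from google.protobuf import json_format
    from google_cloud_pipeline_components.proto.gcp_resources_pb2 import GcpResources

    resources = GcpResources()
    resource = resources.resources.add()
    resource.resource_type = 'BigqueryQueryJob'
    resource.resource_uri = (
        'https://www.googleapis.com/bigquery/v2/projects/test_project/jobs/fake_job?location=US')

    # Serializes to camelCase JSON, e.g.
    # {"resources": [{"resourceType": "BigqueryQueryJob", "resourceUri": "..."}]}
    serialized = json_format.MessageToJson(resources)

    # check_if_job_exists parses it back the same way.
    parsed = json_format.Parse(serialized, GcpResources())
    assert parsed.resources[0].resource_uri == resource.resource_uri
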
diff --git a/components/google-cloud/tests/container/experimental/gcp_launcher/test_launcher.py b/components/google-cloud/tests/container/experimental/gcp_launcher/test_launcher.py
index 08bd1439d902..5c5e3a39a4ed 100644
--- a/components/google-cloud/tests/container/experimental/gcp_launcher/test_launcher.py
+++ b/components/google-cloud/tests/container/experimental/gcp_launcher/test_launcher.py
@@ -23,124 +23,157 @@ class LauncherJobUtilsTests(unittest.TestCase):
 
-    def setUp(self):
-        super(LauncherJobUtilsTests, self).setUp()
-        self._project = 'test_project'
-        self._location = 'test_region'
-        self._gcp_resources = os.path.join(os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'), "test_file_path/test_file.txt")
-
-    @mock.patch.object(
-        google_cloud_pipeline_components.google_cloud_pipeline_components.container.experimental.gcp_launcher
-        .custom_job_remote_runner,
-        'create_custom_job',
-        autospec=True)
-    def test_launcher_on_custom_job_type_calls_custom_job_remote_runner(
-            self, mock_custom_job_remote_runner):
-        job_type = 'CustomJob'
-        payload = (
-            '{"display_name": "ContainerComponent", "job_spec": '
-            '{"worker_pool_specs": [{"machine_spec": {"machine_type": '
-            '"n1-standard-4"}, "replica_count": 1, "container_spec": '
-            '{"image_uri": "google/cloud-sdk:latest", "command": ["sh", '
-            '"-c", "set -e -x\\necho \\"$0, this is an output '
-            'parameter\\"\\n", "{{$.inputs.parameters[\'input_text\']}}", '
-            '"{{$.outputs.parameters[\'output_value\'].output_file}}"]}}]}}')
-        input_args = [
-            '--type', job_type, '--project', self._project, '--location',
-            self._location, '--payload', payload, '--gcp_resources',
-            self._gcp_resources, '--extra_arg', 'extra_arg_value'
-        ]
-        launcher.main(input_args)
-        mock_custom_job_remote_runner.assert_called_once_with(
-            type=job_type,
-            project=self._project,
-            location=self._location,
-            payload=payload,
-            gcp_resources=self._gcp_resources)
-
-    @mock.patch.object(
-        google_cloud_pipeline_components.google_cloud_pipeline_components.container.experimental.gcp_launcher
-        .batch_prediction_job_remote_runner,
-        'create_batch_prediction_job',
-        autospec=True)
-    def test_launcher_on_batch_prediction_job_type_calls_batch_prediction_job_remote_runner(
-            self, mock_batch_prediction_job_remote_runner):
-        job_type = 'BatchPredictionJob'
-        payload = (
-            '{"batchPredictionJob": {"displayName": '
-            '"BatchPredictionComponentName", "model": '
-            '"projects/test/locations/test/models/test-model","inputConfig":'
-            ' {"instancesFormat": "CSV","gcsSource": {"uris": '
-            '["test_gcs_source"]}}, "outputConfig": {"predictionsFormat": '
-            '"CSV", "gcsDestination": {"outputUriPrefix": '
-            '"test_gcs_destination"}}}}')
-        input_args = [
-            '--type', job_type, '--project', self._project, '--location',
-            self._location, '--payload', payload, '--gcp_resources',
-            self._gcp_resources, '--executor_input', 'executor_input'
-        ]
-        launcher.main(input_args)
-        mock_batch_prediction_job_remote_runner.assert_called_once_with(
-            type=job_type,
-            project=self._project,
-            location=self._location,
-            payload=payload,
-            gcp_resources=self._gcp_resources,
-            executor_input='executor_input')
+  def setUp(self):
+    super(LauncherJobUtilsTests, self).setUp()
+    self._project = 'test_project'
+    self._location = 'test_region'
+    self._gcp_resources = os.path.join(
+        os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'),
+        'test_file_path/test_file.txt')
+
+  @mock.patch.object(
+      google_cloud_pipeline_components.google_cloud_pipeline_components
+      .container.experimental.gcp_launcher.custom_job_remote_runner,
+      'create_custom_job',
+      autospec=True)
+  def test_launcher_on_custom_job_type_calls_custom_job_remote_runner(
+      self, mock_custom_job_remote_runner):
+    job_type = 'CustomJob'
+    payload = ('{"display_name": "ContainerComponent", "job_spec": '
+               '{"worker_pool_specs": [{"machine_spec": {"machine_type": '
+               '"n1-standard-4"}, "replica_count": 1, "container_spec": '
+               '{"image_uri": "google/cloud-sdk:latest", "command": ["sh", '
+               '"-c", "set -e -x\\necho \\"$0, this is an output '
+               'parameter\\"\\n", "{{$.inputs.parameters[\'input_text\']}}", '
+               '"{{$.outputs.parameters[\'output_value\'].output_file}}"]}}]}}')
+    input_args = [
+        '--type', job_type, '--project', self._project, '--location',
+        self._location, '--payload', payload, '--gcp_resources',
+        self._gcp_resources, '--extra_arg', 'extra_arg_value'
+    ]
+    launcher.main(input_args)
+    mock_custom_job_remote_runner.assert_called_once_with(
+        type=job_type,
+        project=self._project,
+        location=self._location,
+        payload=payload,
+        gcp_resources=self._gcp_resources)
+
+  @mock.patch.object(
+      google_cloud_pipeline_components.google_cloud_pipeline_components
+      .container.experimental.gcp_launcher.batch_prediction_job_remote_runner,
+      'create_batch_prediction_job',
+      autospec=True)
+  def test_launcher_on_batch_prediction_job_type_calls_batch_prediction_job_remote_runner(
+      self, mock_batch_prediction_job_remote_runner):
+    job_type = 'BatchPredictionJob'
+    payload = ('{"batchPredictionJob": {"displayName": '
+               '"BatchPredictionComponentName", "model": '
+               '"projects/test/locations/test/models/test-model","inputConfig":'
+               ' {"instancesFormat": "CSV","gcsSource": {"uris": '
+               '["test_gcs_source"]}}, "outputConfig": {"predictionsFormat": '
+               '"CSV", "gcsDestination": {"outputUriPrefix": '
+               '"test_gcs_destination"}}}}')
+    input_args = [
+        '--type', job_type, '--project', self._project, '--location',
+        self._location, '--payload', payload, '--gcp_resources',
+        self._gcp_resources, '--executor_input', 'executor_input'
+    ]
+    launcher.main(input_args)
+    mock_batch_prediction_job_remote_runner.assert_called_once_with(
+        type=job_type,
+        project=self._project,
+        location=self._location,
+        payload=payload,
+        gcp_resources=self._gcp_resources,
+        executor_input='executor_input')
 
 
 class LauncherUploadModelUtilsTests(unittest.TestCase):
 
-    def setUp(self):
-        super(LauncherUploadModelUtilsTests, self).setUp()
-        self._gcp_resources = os.path.join(os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'), "test_file_path/test_file.txt")
-        self._input_args = [
-            "--type", "UploadModel", "--project", "test_project", "--location",
-            "us_central1", "--payload", "test_payload", "--gcp_resources",
-            self._gcp_resources, "--executor_input", "executor_input"
-        ]
-
-    @mock.patch.object(
-        google_cloud_pipeline_components.google_cloud_pipeline_components.container.experimental.gcp_launcher
-        .upload_model_remote_runner,
-        "upload_model",
-        autospec=True)
-    def test_launcher_on_upload_model_type_calls_upload_model_remote_runner(
-            self, mock_upload_model_remote_runner):
-        launcher.main(self._input_args)
-        mock_upload_model_remote_runner.assert_called_once_with(
-            type='UploadModel',
-            project='test_project',
-            location='us_central1',
-            payload='test_payload',
-            gcp_resources=self._gcp_resources,
-            executor_input='executor_input')
+  def setUp(self):
+    super(LauncherUploadModelUtilsTests, self).setUp()
+    self._gcp_resources = os.path.join(
+        os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'),
+        'test_file_path/test_file.txt')
+    self._input_args = [
+        '--type', 'UploadModel', '--project', 'test_project', '--location',
+        'us_central1', '--payload', 'test_payload', '--gcp_resources',
+        self._gcp_resources, '--executor_input', 'executor_input'
+    ]
+
+  @mock.patch.object(
+      google_cloud_pipeline_components.google_cloud_pipeline_components
+      .container.experimental.gcp_launcher.upload_model_remote_runner,
+      'upload_model',
+      autospec=True)
+  def test_launcher_on_upload_model_type_calls_upload_model_remote_runner(
+      self, mock_upload_model_remote_runner):
+    launcher.main(self._input_args)
+    mock_upload_model_remote_runner.assert_called_once_with(
+        type='UploadModel',
+        project='test_project',
+        location='us_central1',
+        payload='test_payload',
+        gcp_resources=self._gcp_resources,
+        executor_input='executor_input')
 
 
 class LauncherCreateEndpointUtilsTests(unittest.TestCase):
 
-    def setUp(self):
-        super(LauncherCreateEndpointUtilsTests, self).setUp()
-        self._gcp_resources = os.path.join(os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'), "test_file_path/test_file.txt")
-        self._input_args = [
-            "--type", "CreateEndpoint", "--project", "test_project",
-            "--location", "us_central1", "--payload", "test_payload",
-            "--gcp_resources", self._gcp_resources,
-            "--executor_input", "executor_input"
-        ]
-
-    @mock.patch.object(
-        google_cloud_pipeline_components.google_cloud_pipeline_components.container.experimental.gcp_launcher
-        .create_endpoint_remote_runner,
-        "create_endpoint",
-        autospec=True)
-    def test_launcher_on_create_endpoint_type_calls_create_endpoint_remote_runner(
-            self, create_endpoint_remote_runner):
-        launcher.main(self._input_args)
-        create_endpoint_remote_runner.assert_called_once_with(
-            type='CreateEndpoint',
-            project='test_project',
-            location='us_central1',
-            payload='test_payload',
-            gcp_resources=self._gcp_resources,
-            executor_input='executor_input')
+  def setUp(self):
+    super(LauncherCreateEndpointUtilsTests, self).setUp()
+    self._gcp_resources = os.path.join(
+        os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'),
+        'test_file_path/test_file.txt')
+    self._input_args = [
+        '--type', 'CreateEndpoint', '--project', 'test_project', '--location',
+        'us_central1', '--payload', 'test_payload', '--gcp_resources',
+        self._gcp_resources, '--executor_input', 'executor_input'
+    ]
+
+  @mock.patch.object(
+      google_cloud_pipeline_components.google_cloud_pipeline_components
+      .container.experimental.gcp_launcher.create_endpoint_remote_runner,
+      'create_endpoint',
+      autospec=True)
+  def test_launcher_on_create_endpoint_type_calls_create_endpoint_remote_runner(
+      self, create_endpoint_remote_runner):
+    launcher.main(self._input_args)
+    create_endpoint_remote_runner.assert_called_once_with(
+        type='CreateEndpoint',
+        project='test_project',
+        location='us_central1',
+        payload='test_payload',
+        gcp_resources=self._gcp_resources,
+        executor_input='executor_input')
+
+
+class LauncherBigqueryQueryJobUtilsTests(unittest.TestCase):
+
+  def setUp(self):
+    super(LauncherBigqueryQueryJobUtilsTests, self).setUp()
+    self._gcp_resources = os.path.join(
+        os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'),
+        'test_file_path/test_file.txt')
+    self._input_args = [
+        '--type', 'BigqueryQueryJob', '--project', 'test_project',
+        '--location', 'us_central1', '--payload', 'test_payload',
+        '--gcp_resources', self._gcp_resources, '--executor_input',
+        'executor_input'
+    ]
+
+  @mock.patch.object(
+      google_cloud_pipeline_components.google_cloud_pipeline_components
+      .container.experimental.gcp_launcher.bigquery_query_job_remote_runner,
+      'create_bigquery_job',
+      autospec=True)
+  def test_launcher_on_bigquery_query_job_type(
+      self, bigquery_query_job_remote_runner):
+    launcher.main(self._input_args)
+    bigquery_query_job_remote_runner.assert_called_once_with(
+        type='BigqueryQueryJob',
+        project='test_project',
+        location='us_central1',
+        payload='test_payload',
+        gcp_resources=self._gcp_resources,
+        executor_input='executor_input')
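
Putting the pieces together, a direct call into the new runner looks like the following sketch (the project, file paths, and the minimal executor_input are made up for illustration; in a pipeline the launcher assembles these from the component spec and the KFP backend):

    import json
    from google_cloud_pipeline_components.container.experimental.gcp_launcher import bigquery_query_job_remote_runner

    executor_input = json.dumps({
        'outputs': {
            'artifacts': {
                'destinationTable': {
                    'artifacts': [{
                        'metadata': {},
                        'name': 'destination_table',  # illustrative artifact name
                        'type': {'schemaTitle': 'google.BQTable'}
                    }]
                }
            },
            'outputFile': '/tmp/executor_output.json'
        }
    })

    # Creates the job (or resumes polling if /tmp/gcp_resources already records
    # one) and blocks until BigQuery reports a final state.
    bigquery_query_job_remote_runner.create_bigquery_job(
        type='BigqueryQueryJob',
        project='my-project',
        location='US',
        payload=json.dumps({'query': {'query': 'SELECT 1'}}),
        gcp_resources='/tmp/gcp_resources',
        executor_input=executor_input)
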