From 3a634cf7b24100ac9f77cf4caad2c6e369c22720 Mon Sep 17 00:00:00 2001 From: Khor Shu Heng <32997938+khorshuheng@users.noreply.github.com> Date: Tue, 3 Nov 2020 11:33:37 +0800 Subject: [PATCH] Implement list job and get job methods for Dataproc launcher (#1106) * Implement list job and get job methods for Dataproc launcher Signed-off-by: Khor Shu Heng * Add error message to SparkJobFailure Signed-off-by: Khor Shu Heng * Add integration test for dataproc launcher Signed-off-by: Khor Shu Heng * Remove extra line space Signed-off-by: Khor Shu Heng * Fix argument name to integration test script Signed-off-by: Khor Shu Heng * Add error handling for cancelled job Signed-off-by: Khor Shu Heng * Remove dummy test Signed-off-by: Khor Shu Heng Co-authored-by: Khor Shu Heng --- infra/scripts/test-integration.sh | 2 +- sdk/python/feast/pyspark/abc.py | 51 ++++- .../pyspark/launchers/gcloud/dataproc.py | 190 +++++++++++++++--- tests/integration/conftest.py | 7 +- tests/integration/fixtures/__init__.py | 0 tests/integration/fixtures/job_parameters.py | 115 +++++++++++ tests/integration/fixtures/launchers.py | 17 ++ tests/integration/test_launchers.py | 31 +++ tests/integration/test_simple.py | 2 - 9 files changed, 369 insertions(+), 46 deletions(-) create mode 100644 tests/integration/fixtures/__init__.py create mode 100644 tests/integration/fixtures/job_parameters.py create mode 100644 tests/integration/fixtures/launchers.py create mode 100644 tests/integration/test_launchers.py delete mode 100644 tests/integration/test_simple.py diff --git a/infra/scripts/test-integration.sh b/infra/scripts/test-integration.sh index e2116976f9..2a655cfedd 100755 --- a/infra/scripts/test-integration.sh +++ b/infra/scripts/test-integration.sh @@ -4,4 +4,4 @@ python -m pip install --upgrade pip setuptools wheel make install-python python -m pip install -qr tests/requirements.txt -pytest tests/integration/ \ No newline at end of file +pytest tests/integration --dataproc-cluster-name feast-e2e --dataproc-project kf-feast --dataproc-region us-central1 --dataproc-staging-location gs://feast-templocation-kf-feast \ No newline at end of file diff --git a/sdk/python/feast/pyspark/abc.py b/sdk/python/feast/pyspark/abc.py index 52f421d7e0..d350d764c9 100644 --- a/sdk/python/feast/pyspark/abc.py +++ b/sdk/python/feast/pyspark/abc.py @@ -25,6 +25,15 @@ class SparkJobStatus(Enum): COMPLETED = 3 +class SparkJobType(Enum): + HISTORICAL_RETRIEVAL = 0 + BATCH_INGESTION = 1 + STREAM_INGESTION = 2 + + def to_pascal_case(self): + return self.name.title().replace("_", "") + + class SparkJob(abc.ABC): """ Base class for all spark jobs @@ -45,7 +54,8 @@ def get_status(self) -> SparkJobStatus: """ Job Status retrieval - :return: SparkJobStatus + Returns: + SparkJobStatus: Job status """ raise NotImplementedError @@ -62,7 +72,17 @@ class SparkJobParameters(abc.ABC): def get_name(self) -> str: """ Getter for job name - :return: Job name + Returns: + str: Job name. + """ + raise NotImplementedError + + @abc.abstractmethod + def get_job_type(self) -> SparkJobType: + """ + Getter for job type. + Returns: + SparkJobType: Job type enum. """ raise NotImplementedError @@ -70,14 +90,16 @@ def get_name(self) -> str: def get_main_file_path(self) -> str: """ Getter for jar | python path - :return: Full path to file + Returns: + str: Full path to file. """ raise NotImplementedError def get_class_name(self) -> Optional[str]: """ Getter for main class name if it's applicable - :return: java class path, e.g. 
feast.ingestion.IngestionJob + Returns: + Optional[str]: java class path, e.g. feast.ingestion.IngestionJob. """ return None @@ -86,7 +108,8 @@ def get_arguments(self) -> List[str]: """ Getter for job arguments E.g., ["--source", '{"kafka":...}', ...] - :return: List of arguments + Returns: + List[str]: List of arguments. """ raise NotImplementedError @@ -94,7 +117,8 @@ def get_arguments(self) -> List[str]: def get_extra_options(self) -> str: """ Spark job dependencies (expected to resolved from maven) - :return: + Returns: + str: Spark job dependencies. """ raise NotImplementedError @@ -222,7 +246,10 @@ def __init__( def get_name(self) -> str: all_feature_tables_names = [ft["name"] for ft in self._feature_tables] - return f"HistoryRetrieval-{'-'.join(all_feature_tables_names)}" + return f"{self.get_job_type().to_pascal_case()}-{'-'.join(all_feature_tables_names)}" + + def get_job_type(self) -> SparkJobType: + return SparkJobType.HISTORICAL_RETRIEVAL def get_main_file_path(self) -> str: return os.path.join( @@ -302,10 +329,13 @@ def __init__( def get_name(self) -> str: return ( - f"BatchIngestion-{self.get_feature_table_name()}-" + f"{self.get_job_type().to_pascal_case()}-{self.get_feature_table_name()}-" f"{self._start.strftime('%Y-%m-%d')}-{self._end.strftime('%Y-%m-%d')}" ) + def get_job_type(self) -> SparkJobType: + return SparkJobType.BATCH_INGESTION + def _get_redis_config(self): return dict(host=self._redis_host, port=self._redis_port, ssl=self._redis_ssl) @@ -360,7 +390,10 @@ def __init__( self._extra_options = extra_options def get_name(self) -> str: - return f"StreamIngestion-{self.get_feature_table_name()}" + return f"{self.get_job_type().to_pascal_case()}-{self.get_feature_table_name()}" + + def get_job_type(self) -> SparkJobType: + return SparkJobType.STREAM_INGESTION def _get_redis_config(self): return dict(host=self._redis_host, port=self._redis_port, ssl=self._redis_ssl) diff --git a/sdk/python/feast/pyspark/launchers/gcloud/dataproc.py b/sdk/python/feast/pyspark/launchers/gcloud/dataproc.py index 66e548770a..3a428a9f26 100644 --- a/sdk/python/feast/pyspark/launchers/gcloud/dataproc.py +++ b/sdk/python/feast/pyspark/launchers/gcloud/dataproc.py @@ -1,12 +1,12 @@ +import json import os +import time import uuid from functools import partial -from typing import Any, Callable, Dict, List +from typing import Any, Callable, Dict, List, Optional, Tuple from urllib.parse import urlparse -from google.api_core.operation import Operation -from google.cloud import dataproc_v1 -from google.cloud.dataproc_v1 import JobStatus +from google.cloud.dataproc_v1 import Job, JobControllerClient, JobStatus from feast.pyspark.abc import ( BatchIngestionJob, @@ -18,6 +18,7 @@ SparkJobFailure, SparkJobParameters, SparkJobStatus, + SparkJobType, StreamIngestionJob, StreamIngestionJobParameters, ) @@ -25,22 +26,45 @@ class DataprocJobMixin: - def __init__(self, operation: Operation, cancel_fn: Callable[[], None]): + def __init__( + self, job: Job, refresh_fn: Callable[[], Job], cancel_fn: Callable[[], None] + ): """ - :param operation: (google.api.core.operation.Operation): A Future for the spark job result, - returned by the dataproc client. + Implementation of common methods for different types of SparkJob running on Dataproc cluster. + + Args: + job (Job): Dataproc job resource. + refresh_fn (Callable[[], Job]): A function that returns the latest job resource. + cancel_fn (Callable[[], None]): A function that cancel the current job. 
""" - self._operation = operation + self._job = job + self._refresh_fn = refresh_fn self._cancel_fn = cancel_fn def get_id(self) -> str: - return self._operation.metadata.job_id + """ + Getter for the job id. + + Returns: + str: Dataproc job id. + """ + return self._job.reference.job_id def get_status(self) -> SparkJobStatus: - self._operation._refresh_and_update() + """ + Job Status retrieval - status = self._operation.metadata.status - if status.state == JobStatus.State.ERROR: + Returns: + SparkJobStatus: Job status + """ + self._job = self._refresh_fn() + status = self._job.status + if status.state in ( + JobStatus.State.ERROR, + JobStatus.State.CANCEL_PENDING, + JobStatus.State.CANCEL_STARTED, + JobStatus.State.CANCELLED, + ): return SparkJobStatus.FAILED elif status.state == JobStatus.State.RUNNING: return SparkJobStatus.IN_PROGRESS @@ -54,8 +78,59 @@ def get_status(self) -> SparkJobStatus: return SparkJobStatus.COMPLETED def cancel(self): + """ + Manually terminate job + """ self._cancel_fn() + def get_error_message(self) -> Optional[str]: + """ + Getter for the job's error message if applicable. + + Returns: + str: Status detail of the job. Return None if the job is successful. + """ + self._job = self._refresh_fn() + status = self._job.status + if status.state == JobStatus.State.ERROR: + return status.details + elif status.state in ( + JobStatus.State.CANCEL_PENDING, + JobStatus.State.CANCEL_STARTED, + JobStatus.State.CANCELLED, + ): + return "Job was cancelled." + return None + + def block_polling(self, interval_sec=30, timeout_sec=3600) -> SparkJobStatus: + """ + Blocks until the Dataproc job is completed or failed. + + Args: + interval_sec (int): Polling interval. + timeout_sec (int): Timeout limit. + + Returns: + SparkJobStatus: Latest job status + + Raise: + SparkJobFailure: Raise error if the job neither failed nor completed within the timeout limit. + """ + + start = time.time() + while True: + elapsed_time = time.time() - start + if timeout_sec and elapsed_time >= timeout_sec: + raise SparkJobFailure( + f"Job is still not completed after {timeout_sec}." + ) + + status = self.get_status() + if status in [SparkJobStatus.FAILED, SparkJobStatus.COMPLETED]: + break + time.sleep(interval_sec) + return status + class DataprocRetrievalJob(DataprocJobMixin, RetrievalJob): """ @@ -63,7 +138,11 @@ class DataprocRetrievalJob(DataprocJobMixin, RetrievalJob): """ def __init__( - self, operation: Operation, cancel_fn: Callable[[], None], output_file_uri: str + self, + job: Job, + refresh_fn: Callable[[], Job], + cancel_fn: Callable[[], None], + output_file_uri: str, ): """ This is the returned historical feature retrieval job result for DataprocClusterLauncher. @@ -71,18 +150,17 @@ def __init__( Args: output_file_uri (str): Uri to the historical feature retrieval job output file. 
""" - super().__init__(operation, cancel_fn) + super().__init__(job, refresh_fn, cancel_fn) self._output_file_uri = output_file_uri def get_output_file_uri(self, timeout_sec=None, block=True): if not block: return self._output_file_uri - try: - self._operation.result(timeout_sec) - except Exception as err: - raise SparkJobFailure(err) - return self._output_file_uri + status = self.block_polling(timeout_sec=timeout_sec) + if status == SparkJobStatus.COMPLETED: + return self._output_file_uri + raise SparkJobFailure(self.get_error_message()) class DataprocBatchIngestionJob(DataprocJobMixin, BatchIngestionJob): @@ -105,6 +183,7 @@ class DataprocClusterLauncher(JobLauncher): """ EXTERNAL_JARS = ["gs://spark-lib/bigquery/spark-bigquery-latest_2.12.jar"] + JOB_TYPE_LABEL_KEY = "feast_job_type" def __init__( self, cluster_name: str, staging_location: str, region: str, project_id: str, @@ -135,7 +214,7 @@ def __init__( ) self.project_id = project_id self.region = region - self.job_client = dataproc_v1.JobControllerClient( + self.job_client = JobControllerClient( client_options={"api_endpoint": f"{region}-dataproc.googleapis.com:443"} ) @@ -148,12 +227,15 @@ def _stage_files(self, pyspark_script: str, job_id: str) -> str: return f"gs://{self.staging_bucket}/{blob_path}" - def dataproc_submit(self, job_params: SparkJobParameters) -> Operation: + def dataproc_submit( + self, job_params: SparkJobParameters + ) -> Tuple[Job, Callable[[], Job], Callable[[], None]]: local_job_id = str(uuid.uuid4()) main_file_uri = self._stage_files(job_params.get_main_file_path(), local_job_id) job_config: Dict[str, Any] = { "reference": {"job_id": local_job_id}, "placement": {"cluster_name": self.cluster_name}, + "labels": {self.JOB_TYPE_LABEL_KEY: job_params.get_job_type().name.lower()}, } if job_params.get_class_name(): job_config.update( @@ -175,7 +257,8 @@ def dataproc_submit(self, job_params: SparkJobParameters) -> Operation: } } ) - return self.job_client.submit_job_as_operation( + + job = self.job_client.submit_job( request={ "project_id": self.project_id, "region": self.region, @@ -183,6 +266,16 @@ def dataproc_submit(self, job_params: SparkJobParameters) -> Operation: } ) + refresh_fn = partial( + self.job_client.get_job, + project_id=self.project_id, + region=self.region, + job_id=job.reference.job_id, + ) + cancel_fn = partial(self.dataproc_cancel, job.reference.job_id) + + return job, refresh_fn, cancel_fn + def dataproc_cancel(self, job_id): self.job_client.cancel_job( project_id=self.project_id, region=self.region, job_id=job_id @@ -191,31 +284,62 @@ def dataproc_cancel(self, job_id): def historical_feature_retrieval( self, job_params: RetrievalJobParameters ) -> RetrievalJob: - operation = self.dataproc_submit(job_params) - cancel_fn = partial(self.dataproc_cancel, operation.metadata.job_id) + job, refresh_fn, cancel_fn = self.dataproc_submit(job_params) return DataprocRetrievalJob( - operation, cancel_fn, job_params.get_destination_path() + job, refresh_fn, cancel_fn, job_params.get_destination_path() ) def offline_to_online_ingestion( self, ingestion_job_params: BatchIngestionJobParameters ) -> BatchIngestionJob: - operation = self.dataproc_submit(ingestion_job_params) - cancel_fn = partial(self.dataproc_cancel, operation.metadata.job_id) - return DataprocBatchIngestionJob(operation, cancel_fn) + job, refresh_fn, cancel_fn = self.dataproc_submit(ingestion_job_params) + return DataprocBatchIngestionJob(job, refresh_fn, cancel_fn) def start_stream_to_online_ingestion( self, ingestion_job_params: 
StreamIngestionJobParameters ) -> StreamIngestionJob: - operation = self.dataproc_submit(ingestion_job_params) - cancel_fn = partial(self.dataproc_cancel, operation.metadata.job_id) - return DataprocStreamingIngestionJob(operation, cancel_fn) + job, refresh_fn, cancel_fn = self.dataproc_submit(ingestion_job_params) + return DataprocStreamingIngestionJob(job, refresh_fn, cancel_fn) def stage_dataframe(self, df, event_timestamp_column: str): raise NotImplementedError def get_job_by_id(self, job_id: str) -> SparkJob: - raise NotImplementedError + job = self.job_client.get_job( + project_id=self.project_id, region=self.region, job_id=job_id + ) + return self._dataproc_job_to_spark_job(job) + + def _dataproc_job_to_spark_job(self, job: Job) -> SparkJob: + job_type = job.labels[self.JOB_TYPE_LABEL_KEY] + job_id = job.reference.job_id + refresh_fn = partial( + self.job_client.get_job, + project_id=self.project_id, + region=self.region, + job_id=job_id, + ) + cancel_fn = partial(self.dataproc_cancel, job_id) + + if job_type == SparkJobType.HISTORICAL_RETRIEVAL.name.lower(): + output_path = json.loads(job.pyspark_job.args[-1])["path"] + return DataprocRetrievalJob(job, refresh_fn, cancel_fn, output_path) + + if job_type == SparkJobType.BATCH_INGESTION.name.lower(): + return DataprocBatchIngestionJob(job, refresh_fn, cancel_fn) + + if job_type == SparkJobType.STREAM_INGESTION.name.lower(): + return DataprocStreamingIngestionJob(job, refresh_fn, cancel_fn) + + raise ValueError(f"Unrecognized job type: {job_type}") def list_jobs(self, include_terminated: bool) -> List[SparkJob]: - raise NotImplementedError + job_filter = f"labels.{self.JOB_TYPE_LABEL_KEY} = * AND clusterName = {self.cluster_name}" + if not include_terminated: + job_filter = job_filter + "AND status.state = ACTIVE" + return [ + self._dataproc_job_to_spark_job(job) + for job in self.job_client.list_jobs( + project_id=self.project_id, region=self.region, filter=job_filter + ) + ] diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 4b82f1890c..88ab8743a5 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,2 +1,7 @@ def pytest_addoption(parser): - pass + parser.addoption("--dataproc-cluster-name", action="store") + parser.addoption("--dataproc-region", action="store") + parser.addoption("--dataproc-project", action="store") + parser.addoption("--dataproc-staging-location", action="store") + parser.addoption("--redis-url", action="store") + parser.addoption("--redis-cluster", action="store_true") diff --git a/tests/integration/fixtures/__init__.py b/tests/integration/fixtures/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integration/fixtures/job_parameters.py b/tests/integration/fixtures/job_parameters.py new file mode 100644 index 0000000000..e1024354ae --- /dev/null +++ b/tests/integration/fixtures/job_parameters.py @@ -0,0 +1,115 @@ +import tempfile +import uuid +from datetime import datetime +from os import path +from urllib.parse import urlparse + +import numpy as np +import pandas as pd +import pytest +from google.cloud import storage +from pytz import utc + +from feast.pyspark.abc import RetrievalJobParameters + + +@pytest.fixture(scope="module") +def customer_entity() -> pd.DataFrame: + return pd.DataFrame( + np.array([[1001, datetime(year=2020, month=9, day=1, tzinfo=utc)]]), + columns=["customer_id", "event_timestamp"], + ) + + +@pytest.fixture(scope="module") +def customer_feature() -> pd.DataFrame: + return pd.DataFrame( + np.array( 
+ [ + [ + 1001, + 100.0, + datetime(year=2020, month=9, day=1, tzinfo=utc), + datetime(year=2020, month=9, day=1, tzinfo=utc), + ], + ] + ), + columns=[ + "customer_id", + "total_transactions", + "event_timestamp", + "created_timestamp", + ], + ) + + +def upload_dataframe_to_gcs_as_parquet(df: pd.DataFrame, staging_location: str): + gcs_client = storage.Client() + staging_location_uri = urlparse(staging_location) + staging_bucket = staging_location_uri.netloc + remote_path = staging_location_uri.path.lstrip("/") + gcs_bucket = gcs_client.get_bucket(staging_bucket) + temp_dir = str(uuid.uuid4()) + df_remote_path = path.join(remote_path, temp_dir) + blob = gcs_bucket.blob(df_remote_path) + with tempfile.NamedTemporaryFile() as df_local_path: + df.to_parquet(df_local_path.name) + blob.upload_from_filename(df_local_path.name) + return path.join(staging_location, df_remote_path) + + +def new_retrieval_job_params( + entity_source_uri: str, feature_source_uri: str, destination_uri: str +) -> RetrievalJobParameters: + entity_source = { + "file": { + "format": "parquet", + "path": entity_source_uri, + "event_timestamp_column": "event_timestamp", + } + } + + feature_tables_sources = [ + { + "file": { + "format": "parquet", + "path": feature_source_uri, + "event_timestamp_column": "event_timestamp", + "created_timestamp_column": "created_timestamp", + } + } + ] + + feature_tables = [ + { + "name": "customer_transactions", + "entities": [{"name": "customer", "type": "int32"}], + } + ] + + destination = {"format": "parquet", "path": destination_uri} + + return RetrievalJobParameters( + feature_tables=feature_tables, + feature_tables_sources=feature_tables_sources, + entity_source=entity_source, + destination=destination, + ) + + +@pytest.fixture(scope="module") +def dataproc_retrieval_job_params( + pytestconfig, customer_entity, customer_feature +) -> RetrievalJobParameters: + staging_location = pytestconfig.getoption("--dataproc-staging-location") + entity_source_uri = upload_dataframe_to_gcs_as_parquet( + customer_entity, staging_location + ) + feature_source_uri = upload_dataframe_to_gcs_as_parquet( + customer_feature, staging_location + ) + destination_uri = path.join(staging_location, str(uuid.uuid4())) + + return new_retrieval_job_params( + entity_source_uri, feature_source_uri, destination_uri + ) diff --git a/tests/integration/fixtures/launchers.py b/tests/integration/fixtures/launchers.py new file mode 100644 index 0000000000..ebe93172d1 --- /dev/null +++ b/tests/integration/fixtures/launchers.py @@ -0,0 +1,17 @@ +import pytest + +from feast.pyspark.launchers.gcloud import DataprocClusterLauncher + + +@pytest.fixture +def dataproc_launcher(pytestconfig) -> DataprocClusterLauncher: + cluster_name = pytestconfig.getoption("--dataproc-cluster-name") + region = pytestconfig.getoption("--dataproc-region") + project_id = pytestconfig.getoption("--dataproc-project") + staging_location = pytestconfig.getoption("--dataproc-staging-location") + return DataprocClusterLauncher( + cluster_name=cluster_name, + staging_location=staging_location, + region=region, + project_id=project_id, + ) diff --git a/tests/integration/test_launchers.py b/tests/integration/test_launchers.py new file mode 100644 index 0000000000..fb6731afa2 --- /dev/null +++ b/tests/integration/test_launchers.py @@ -0,0 +1,31 @@ +from time import sleep + +from feast.pyspark.abc import RetrievalJobParameters, SparkJobStatus +from feast.pyspark.launchers.gcloud import DataprocClusterLauncher + +from .fixtures.job_parameters import 
customer_entity # noqa: F401 +from .fixtures.job_parameters import customer_feature # noqa: F401 +from .fixtures.job_parameters import dataproc_retrieval_job_params # noqa: F401 +from .fixtures.launchers import dataproc_launcher # noqa: F401 + + +def test_dataproc_job_api( + dataproc_launcher: DataprocClusterLauncher, # noqa: F811 + dataproc_retrieval_job_params: RetrievalJobParameters, # noqa: F811 +): + job = dataproc_launcher.historical_feature_retrieval(dataproc_retrieval_job_params) + job_id = job.get_id() + retrieved_job = dataproc_launcher.get_job_by_id(job_id) + assert retrieved_job.get_id() == job_id + status = retrieved_job.get_status() + assert status in [ + SparkJobStatus.STARTING, + SparkJobStatus.IN_PROGRESS, + SparkJobStatus.COMPLETED, + ] + active_job_ids = [ + job.get_id() for job in dataproc_launcher.list_jobs(include_terminated=False) + ] + assert job_id in active_job_ids + retrieved_job.cancel() + assert retrieved_job.get_status() == SparkJobStatus.FAILED diff --git a/tests/integration/test_simple.py b/tests/integration/test_simple.py deleted file mode 100644 index 70ce8efe43..0000000000 --- a/tests/integration/test_simple.py +++ /dev/null @@ -1,2 +0,0 @@ -def test_success(): - pass
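
For reference, a minimal standalone sketch of the naming scheme behind the new SparkJobType enum: get_name() on each job-parameters class now builds its prefix from to_pascal_case(), so the three job types yield HistoricalRetrieval, BatchIngestion and StreamIngestion. The snippet below reproduces just that helper outside of sdk/python/feast/pyspark/abc.py, for illustration only:

    # Standalone sketch of the SparkJobType naming scheme added in abc.py.
    from enum import Enum


    class SparkJobType(Enum):
        HISTORICAL_RETRIEVAL = 0
        BATCH_INGESTION = 1
        STREAM_INGESTION = 2

        def to_pascal_case(self) -> str:
            # "HISTORICAL_RETRIEVAL" -> "Historical_Retrieval" -> "HistoricalRetrieval"
            return self.name.title().replace("_", "")


    # The prefixes that get_name() derives from the enum:
    assert SparkJobType.HISTORICAL_RETRIEVAL.to_pascal_case() == "HistoricalRetrieval"
    assert SparkJobType.BATCH_INGESTION.to_pascal_case() == "BatchIngestion"
    assert SparkJobType.STREAM_INGESTION.to_pascal_case() == "StreamIngestion"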
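
DataprocJobMixin.block_polling is a bounded poll loop over get_status(); the same pattern is sketched below in self-contained form. The local SparkJobStatus enum and the TimeoutError are stand-ins for illustration only — the real helper uses feast.pyspark.abc.SparkJobStatus and raises SparkJobFailure on timeout:

    import time
    from enum import Enum, auto
    from typing import Callable


    class SparkJobStatus(Enum):
        # Local stand-in for feast.pyspark.abc.SparkJobStatus; member values are illustrative.
        STARTING = auto()
        IN_PROGRESS = auto()
        FAILED = auto()
        COMPLETED = auto()


    def block_polling(
        status_fn: Callable[[], SparkJobStatus],
        interval_sec: int = 30,
        timeout_sec: int = 3600,
    ) -> SparkJobStatus:
        """Poll status_fn until the job is FAILED or COMPLETED, or raise after timeout_sec."""
        start = time.time()
        while True:
            if timeout_sec and time.time() - start >= timeout_sec:
                raise TimeoutError(f"Job is still not completed after {timeout_sec} seconds.")
            status = status_fn()
            if status in (SparkJobStatus.FAILED, SparkJobStatus.COMPLETED):
                return status
            time.sleep(interval_sec)


    # Example: a status source that reaches a terminal state on the third poll.
    _polls = iter([SparkJobStatus.STARTING, SparkJobStatus.IN_PROGRESS, SparkJobStatus.COMPLETED])
    print(block_polling(lambda: next(_polls), interval_sec=0, timeout_sec=10))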
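
dataproc_submit hands each job wrapper a refresh_fn and a cancel_fn built with functools.partial around the JobControllerClient calls, so a DataprocRetrievalJob or ingestion job can refresh and cancel itself without keeping a reference to the launcher. A miniature of the same pattern, with a hypothetical FakeClient standing in for the real client:

    from functools import partial


    class FakeClient:
        # Stand-in for google.cloud.dataproc_v1.JobControllerClient, keyword args as in the patch.
        def get_job(self, project_id, region, job_id):
            return f"job {job_id} in {project_id}/{region}"

        def cancel_job(self, project_id, region, job_id):
            print(f"cancelled {job_id}")


    client = FakeClient()
    refresh_fn = partial(client.get_job, project_id="kf-feast", region="us-central1", job_id="abc123")
    cancel_fn = partial(client.cancel_job, project_id="kf-feast", region="us-central1", job_id="abc123")

    print(refresh_fn())  # re-fetches the job resource on demand (used by get_status/get_error_message)
    cancel_fn()          # terminates the job when SparkJob.cancel() is called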
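
get_job_by_id and list_jobs round-trip jobs through the new feast_job_type label: list_jobs filters on the label plus the cluster name (filter clauses are joined with a spaced AND), and retrieval jobs recover their output path from the JSON destination passed as the last PySpark argument. A sketch of those two pieces in isolation; the sample cluster name and argument below are illustrative:

    import json

    JOB_TYPE_LABEL_KEY = "feast_job_type"


    def build_job_filter(cluster_name: str, include_terminated: bool) -> str:
        # Dataproc jobs.list filter clauses are space-separated around AND.
        job_filter = f"labels.{JOB_TYPE_LABEL_KEY} = * AND clusterName = {cluster_name}"
        if not include_terminated:
            job_filter += " AND status.state = ACTIVE"
        return job_filter


    def output_path_from_pyspark_args(args) -> str:
        # The retrieval job passes its destination as the last argument, a JSON
        # object of the form {"format": "parquet", "path": "gs://..."}.
        return json.loads(args[-1])["path"]


    print(build_job_filter("feast-e2e", include_terminated=False))
    print(output_path_from_pyspark_args(["--foo", '{"format": "parquet", "path": "gs://bucket/out"}']))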
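
Putting the new methods together, tests/integration/test_launchers.py exercises roughly the flow sketched below. This is a usage sketch, not runnable without a real Dataproc cluster and GCS bucket; the resource names mirror infra/scripts/test-integration.sh, the parquet URIs are placeholders, and it assumes it is run from the repo root so the test fixture module is importable:

    from feast.pyspark.abc import SparkJobStatus
    from feast.pyspark.launchers.gcloud import DataprocClusterLauncher
    from tests.integration.fixtures.job_parameters import new_retrieval_job_params

    launcher = DataprocClusterLauncher(
        cluster_name="feast-e2e",                             # values mirror test-integration.sh;
        staging_location="gs://feast-templocation-kf-feast",  # substitute your own resources
        region="us-central1",
        project_id="kf-feast",
    )

    # Placeholder parquet URIs; the tests produce these with upload_dataframe_to_gcs_as_parquet().
    params = new_retrieval_job_params(
        entity_source_uri="gs://feast-templocation-kf-feast/entities.parquet",
        feature_source_uri="gs://feast-templocation-kf-feast/features.parquet",
        destination_uri="gs://feast-templocation-kf-feast/output",
    )

    job = launcher.historical_feature_retrieval(params)
    job_id = job.get_id()

    # A submitted job can be rehydrated by id or discovered via list_jobs().
    assert launcher.get_job_by_id(job_id).get_id() == job_id
    assert job_id in [j.get_id() for j in launcher.list_jobs(include_terminated=False)]

    # cancel() maps Dataproc's CANCEL_PENDING/CANCEL_STARTED/CANCELLED states to FAILED.
    job.cancel()
    assert job.get_status() == SparkJobStatus.FAILED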