diff --git a/.github/workflows/master_only.yml b/.github/workflows/master_only.yml index 2852c559dd..20caebe0af 100644 --- a/.github/workflows/master_only.yml +++ b/.github/workflows/master_only.yml @@ -79,7 +79,15 @@ jobs: - uses: stCarolas/setup-maven@v3 with: maven-version: 3.6.3 - - name: build-jar + - name: Publish develop version of ingestion job + run: | + if [ ${GITHUB_REF#refs/*/} == "master" ]; then + make build-java-no-tests REVISION=develop + gsutil cp ./spark/ingestion/target/feast-ingestion-spark-develop.jar gs://${PUBLISH_BUCKET}/spark/ingestion/ + fi + - name: Get version + run: echo ::set-env name=RELEASE_VERSION::${GITHUB_REF#refs/*/} + - name: Publish tagged version of ingestion job run: | SEMVER_REGEX='^v[0-9]+\.[0-9]+\.[0-9]+(-([0-9A-Za-z-]+(\.[0-9A-Za-z-]+)*))?$' if echo "${RELEASE_VERSION}" | grep -P "$SEMVER_REGEX" &>/dev/null ; then diff --git a/core/src/main/java/feast/core/validators/DataSourceValidator.java b/core/src/main/java/feast/core/validators/DataSourceValidator.java index 223906d0e6..f36e3609d7 100644 --- a/core/src/main/java/feast/core/validators/DataSourceValidator.java +++ b/core/src/main/java/feast/core/validators/DataSourceValidator.java @@ -51,6 +51,8 @@ public static void validate(DataSource spec) { spec.getKafkaOptions().getMessageFormat().getProtoFormat().getClassPath(), "FeatureTable"); break; + case AVRO_FORMAT: + break; default: throw new UnsupportedOperationException( String.format( @@ -68,6 +70,8 @@ public static void validate(DataSource spec) { spec.getKinesisOptions().getRecordFormat().getProtoFormat().getClassPath(), "FeatureTable"); break; + case AVRO_FORMAT: + break; default: throw new UnsupportedOperationException( String.format("Unsupported Stream Format for Kafka Source Type: %s", recordFormat)); diff --git a/infra/charts/feast/values.yaml b/infra/charts/feast/values.yaml index b3a2df67be..a42349c376 100644 --- a/infra/charts/feast/values.yaml +++ b/infra/charts/feast/values.yaml @@ -4,7 +4,7 @@ feast-core: feast-jobcontroller: # feast-jobcontroller.enabled -- Flag to install Feast Job Controller - enabled: true + enabled: false feast-online-serving: # feast-online-serving.enabled -- Flag to install Feast Online Serving diff --git a/sdk/python/feast/constants.py b/sdk/python/feast/constants.py index e55c748211..071e867d4c 100644 --- a/sdk/python/feast/constants.py +++ b/sdk/python/feast/constants.py @@ -123,7 +123,7 @@ class AuthProvider(Enum): # Authentication Provider - Google OpenID/OAuth CONFIG_AUTH_PROVIDER: "google", CONFIG_SPARK_LAUNCHER: "dataproc", - CONFIG_SPARK_INGESTION_JOB_JAR: "gs://feast-jobs/feast-ingestion-spark-0.8-SNAPSHOT.jar", + CONFIG_SPARK_INGESTION_JOB_JAR: "gs://feast-jobs/feast-ingestion-spark-develop.jar", CONFIG_REDIS_HOST: "localhost", CONFIG_REDIS_PORT: "6379", CONFIG_REDIS_SSL: "False", diff --git a/sdk/python/feast/pyspark/abc.py b/sdk/python/feast/pyspark/abc.py index ed65672950..45005e20d1 100644 --- a/sdk/python/feast/pyspark/abc.py +++ b/sdk/python/feast/pyspark/abc.py @@ -19,6 +19,7 @@ class SparkJobFailure(Exception): class SparkJobStatus(Enum): + STARTING = 0 IN_PROGRESS = 1 FAILED = 2 COMPLETED = 3 @@ -48,6 +49,13 @@ def get_status(self) -> SparkJobStatus: """ raise NotImplementedError + @abc.abstractmethod + def cancel(self): + """ + Manually terminate job + """ + raise NotImplementedError + class SparkJobParameters(abc.ABC): @abc.abstractmethod diff --git a/sdk/python/feast/pyspark/historical_feature_retrieval_job.py b/sdk/python/feast/pyspark/historical_feature_retrieval_job.py index 1d221e0f32..3de7ec660f 100644 --- a/sdk/python/feast/pyspark/historical_feature_retrieval_job.py +++ b/sdk/python/feast/pyspark/historical_feature_retrieval_job.py @@ -75,6 +75,12 @@ class FileSource(Source): options (Optional[Dict[str, str]]): Options to be passed to spark while reading the file source. """ + PROTO_FORMAT_TO_SPARK = { + "ParquetFormat": "parquet", + "AvroFormat": "avro", + "CSVFormat": "csv", + } + def __init__( self, format: str, @@ -147,7 +153,7 @@ def spark_path(self) -> str: def _source_from_dict(dct: Dict) -> Source: if "file" in dct.keys(): return FileSource( - dct["file"]["format"], + FileSource.PROTO_FORMAT_TO_SPARK[dct["file"]["format"]["json_class"]], dct["file"]["path"], dct["file"]["event_timestamp_column"], dct["file"].get("created_timestamp_column"), @@ -635,7 +641,7 @@ def retrieve_historical_features( Example: >>> entity_source_conf = { - "format": "csv", + "format": {"jsonClass": "ParquetFormat"}, "path": "file:///some_dir/customer_driver_pairs.csv"), "options": {"inferSchema": "true", "header": "true"}, "field_mapping": {"id": "driver_id"} @@ -643,12 +649,12 @@ def retrieve_historical_features( >>> feature_tables_sources_conf = [ { - "format": "parquet", + "format": {"json_class": "ParquetFormat"}, "path": "gs://some_bucket/bookings.parquet"), "field_mapping": {"id": "driver_id"} }, { - "format": "avro", + "format": {"json_class": "AvroFormat", schema_json: "..avro schema.."}, "path": "s3://some_bucket/transactions.avro"), } ] diff --git a/sdk/python/feast/pyspark/launcher.py b/sdk/python/feast/pyspark/launcher.py index 5156a7c541..defd181442 100644 --- a/sdk/python/feast/pyspark/launcher.py +++ b/sdk/python/feast/pyspark/launcher.py @@ -86,9 +86,6 @@ def resolve_launcher(config: Config) -> JobLauncher: return _launchers[config.get(CONFIG_SPARK_LAUNCHER)](config) -_SOURCES = {FileSource: "file", BigQuerySource: "bq", KafkaSource: "kafka"} - - def _source_to_argument(source: DataSource): common_properties = { "field_mapping": dict(source.field_mapping), @@ -97,20 +94,28 @@ def _source_to_argument(source: DataSource): "date_partition_column": source.date_partition_column, } - kind = _SOURCES[type(source)] properties = {**common_properties} + if isinstance(source, FileSource): properties["path"] = source.file_options.file_url - properties["format"] = str(source.file_options.file_format) - return {kind: properties} + properties["format"] = dict( + json_class=source.file_options.file_format.__class__.__name__ + ) + return {"file": properties} + if isinstance(source, BigQuerySource): properties["table_ref"] = source.bigquery_options.table_ref - return {kind: properties} + return {"bq": properties} + if isinstance(source, KafkaSource): - properties["topic"] = source.kafka_options.topic - properties["classpath"] = source.kafka_options.class_path properties["bootstrap_servers"] = source.kafka_options.bootstrap_servers - return {kind: properties} + properties["topic"] = source.kafka_options.topic + properties["format"] = { + **source.kafka_options.message_format.__dict__, + "json_class": source.kafka_options.message_format.__class__.__name__, + } + return {"kafka": properties} + raise NotImplementedError(f"Unsupported Datasource: {type(source)}") diff --git a/sdk/python/feast/pyspark/launchers/aws/emr.py b/sdk/python/feast/pyspark/launchers/aws/emr.py index 3b87f91672..c23d07fb59 100644 --- a/sdk/python/feast/pyspark/launchers/aws/emr.py +++ b/sdk/python/feast/pyspark/launchers/aws/emr.py @@ -62,6 +62,9 @@ def get_status(self) -> SparkJobStatus: # we should never get here raise Exception("Invalid EMR state") + def cancel(self): + raise NotImplementedError + class EmrRetrievalJob(EmrJobMixin, RetrievalJob): """ diff --git a/sdk/python/feast/pyspark/launchers/gcloud/dataproc.py b/sdk/python/feast/pyspark/launchers/gcloud/dataproc.py index 1d6a0585c2..ff314194c9 100644 --- a/sdk/python/feast/pyspark/launchers/gcloud/dataproc.py +++ b/sdk/python/feast/pyspark/launchers/gcloud/dataproc.py @@ -45,6 +45,9 @@ def get_status(self) -> SparkJobStatus: return SparkJobStatus.FAILED + def cancel(self): + self._operation.cancel() + class DataprocRetrievalJob(DataprocJobMixin, RetrievalJob): """ @@ -71,7 +74,13 @@ def get_output_file_uri(self, timeout_sec=None): class DataprocBatchIngestionJob(DataprocJobMixin, BatchIngestionJob): """ - Ingestion job result for a Dataproc cluster + Batch Ingestion job result for a Dataproc cluster + """ + + +class DataprocStreamingIngestionJob(DataprocJobMixin, StreamIngestionJob): + """ + Streaming Ingestion job result for a Dataproc cluster """ @@ -151,14 +160,14 @@ def historical_feature_retrieval( ) def offline_to_online_ingestion( - self, job_params: BatchIngestionJobParameters + self, ingestion_job_params: BatchIngestionJobParameters ) -> BatchIngestionJob: - return DataprocBatchIngestionJob(self.dataproc_submit(job_params)) + return DataprocBatchIngestionJob(self.dataproc_submit(ingestion_job_params)) def start_stream_to_online_ingestion( self, ingestion_job_params: StreamIngestionJobParameters ) -> StreamIngestionJob: - raise NotImplementedError + return DataprocStreamingIngestionJob(self.dataproc_submit(ingestion_job_params)) def stage_dataframe( self, df, event_timestamp_column: str, created_timestamp_column: str, diff --git a/sdk/python/feast/pyspark/launchers/standalone/local.py b/sdk/python/feast/pyspark/launchers/standalone/local.py index 2ed1c26a08..fdbbb45064 100644 --- a/sdk/python/feast/pyspark/launchers/standalone/local.py +++ b/sdk/python/feast/pyspark/launchers/standalone/local.py @@ -1,6 +1,11 @@ import os +import socket import subprocess import uuid +from contextlib import closing + +import requests +from requests.exceptions import RequestException from feast.pyspark.abc import ( BatchIngestionJob, @@ -16,17 +21,53 @@ ) +def _find_free_port(): + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] + + class StandaloneClusterJobMixin: - def __init__(self, job_id: str, process: subprocess.Popen): + def __init__( + self, job_id: str, job_name: str, process: subprocess.Popen, ui_port: int = None + ): self._job_id = job_id + self._job_name = job_name self._process = process + self._ui_port = ui_port def get_id(self) -> str: return self._job_id + def check_if_started(self): + if not self._ui_port: + return True + + try: + applications = requests.get( + f"http://localhost:{self._ui_port}/api/v1/applications" + ).json() + except RequestException: + return False + + app = next( + iter(app for app in applications if app["name"] == self._job_name), None + ) + if not app: + return False + + stages = requests.get( + f"http://localhost:{self._ui_port}/api/v1/applications/{app['id']}/stages" + ).json() + return bool(stages) + def get_status(self) -> SparkJobStatus: code = self._process.poll() if code is None: + if not self.check_if_started(): + return SparkJobStatus.STARTING + return SparkJobStatus.IN_PROGRESS if code != 0: @@ -34,10 +75,23 @@ def get_status(self) -> SparkJobStatus: return SparkJobStatus.COMPLETED + def cancel(self): + self._process.terminate() + class StandaloneClusterBatchIngestionJob(StandaloneClusterJobMixin, BatchIngestionJob): """ - Ingestion job result for a standalone spark cluster + Batch Ingestion job result for a standalone spark cluster + """ + + pass + + +class StandaloneClusterStreamingIngestionJob( + StandaloneClusterJobMixin, StreamIngestionJob +): + """ + Streaming Ingestion job result for a standalone spark cluster """ pass @@ -48,7 +102,13 @@ class StandaloneClusterRetrievalJob(StandaloneClusterJobMixin, RetrievalJob): Historical feature retrieval job result for a standalone spark cluster """ - def __init__(self, job_id: str, process: subprocess.Popen, output_file_uri: str): + def __init__( + self, + job_id: str, + job_name: str, + process: subprocess.Popen, + output_file_uri: str, + ): """ This is the returned historical feature retrieval job result for StandaloneClusterLauncher. @@ -57,7 +117,7 @@ def __init__(self, job_id: str, process: subprocess.Popen, output_file_uri: str) process (subprocess.Popen): Pyspark driver process, spawned by the launcher. output_file_uri (str): Uri to the historical feature retrieval job output file. """ - super().__init__(job_id, process) + super().__init__(job_id, job_name, process) self._output_file_uri = output_file_uri def get_output_file_uri(self, timeout_sec: int = None): @@ -100,7 +160,9 @@ def __init__(self, master_url: str, spark_home: str = None): def spark_submit_script_path(self): return os.path.join(self.spark_home, "bin/spark-submit") - def spark_submit(self, job_params: SparkJobParameters) -> subprocess.Popen: + def spark_submit( + self, job_params: SparkJobParameters, ui_port: int = None + ) -> subprocess.Popen: submission_cmd = [ self.spark_submit_script_path, "--master", @@ -112,6 +174,9 @@ def spark_submit(self, job_params: SparkJobParameters) -> subprocess.Popen: if job_params.get_class_name(): submission_cmd.extend(["--class", job_params.get_class_name()]) + if ui_port: + submission_cmd.extend(["--conf", f"spark.ui.port={ui_port}"]) + submission_cmd.append(job_params.get_main_file_path()) submission_cmd.extend(job_params.get_arguments()) @@ -122,19 +187,35 @@ def historical_feature_retrieval( ) -> RetrievalJob: job_id = str(uuid.uuid4()) return StandaloneClusterRetrievalJob( - job_id, self.spark_submit(job_params), job_params.get_destination_path() + job_id, + job_params.get_name(), + self.spark_submit(job_params), + job_params.get_destination_path(), ) def offline_to_online_ingestion( - self, job_params: BatchIngestionJobParameters + self, ingestion_job_params: BatchIngestionJobParameters ) -> BatchIngestionJob: job_id = str(uuid.uuid4()) - return StandaloneClusterBatchIngestionJob(job_id, self.spark_submit(job_params)) + ui_port = _find_free_port() + return StandaloneClusterBatchIngestionJob( + job_id, + ingestion_job_params.get_name(), + self.spark_submit(ingestion_job_params, ui_port), + ui_port, + ) def start_stream_to_online_ingestion( self, ingestion_job_params: StreamIngestionJobParameters ) -> StreamIngestionJob: - raise NotImplementedError + job_id = str(uuid.uuid4()) + ui_port = _find_free_port() + return StandaloneClusterStreamingIngestionJob( + job_id, + ingestion_job_params.get_name(), + self.spark_submit(ingestion_job_params, ui_port), + ui_port, + ) def stage_dataframe( self, df, event_timestamp_column: str, created_timestamp_column: str, diff --git a/sdk/python/requirements-ci.txt b/sdk/python/requirements-ci.txt index 62b335e557..eaeb81387e 100644 --- a/sdk/python/requirements-ci.txt +++ b/sdk/python/requirements-ci.txt @@ -16,3 +16,5 @@ pandavro==1.5.* moto mypy mypy-protobuf +avro==1.10.0 +confluent_kafka diff --git a/sdk/python/requirements-dev.txt b/sdk/python/requirements-dev.txt index ac7813ec79..f34c0c8924 100644 --- a/sdk/python/requirements-dev.txt +++ b/sdk/python/requirements-dev.txt @@ -26,7 +26,6 @@ Sphinx sphinx-rtd-theme toml==0.10.* tqdm==4.* -confluent_kafka google pandavro==1.5.* kafka-python==1.* diff --git a/sdk/python/tests/test_as_of_join.py b/sdk/python/tests/test_as_of_join.py index a6288855c7..23ce94e176 100644 --- a/sdk/python/tests/test_as_of_join.py +++ b/sdk/python/tests/test_as_of_join.py @@ -582,7 +582,7 @@ def test_historical_feature_retrieval(spark: SparkSession): test_data_dir = path.join(pathlib.Path(__file__).parent.absolute(), "data") entity_source = { "file": { - "format": "csv", + "format": {"json_class": "CSVFormat"}, "path": f"file://{path.join(test_data_dir, 'customer_driver_pairs.csv')}", "event_timestamp_column": "event_timestamp", "options": {"inferSchema": "true", "header": "true"}, @@ -590,7 +590,7 @@ def test_historical_feature_retrieval(spark: SparkSession): } booking_source = { "file": { - "format": "csv", + "format": {"json_class": "CSVFormat"}, "path": f"file://{path.join(test_data_dir, 'bookings.csv')}", "event_timestamp_column": "event_timestamp", "created_timestamp_column": "created_timestamp", @@ -599,7 +599,7 @@ def test_historical_feature_retrieval(spark: SparkSession): } transaction_source = { "file": { - "format": "csv", + "format": {"json_class": "CSVFormat"}, "path": f"file://{path.join(test_data_dir, 'transactions.csv')}", "event_timestamp_column": "event_timestamp", "created_timestamp_column": "created_timestamp", @@ -653,7 +653,7 @@ def test_historical_feature_retrieval_with_mapping(spark: SparkSession): test_data_dir = path.join(pathlib.Path(__file__).parent.absolute(), "data") entity_source = { "file": { - "format": "csv", + "format": {"json_class": "CSVFormat"}, "path": f"file://{path.join(test_data_dir, 'column_mapping_test_entity.csv')}", "event_timestamp_column": "event_timestamp", "field_mapping": {"id": "customer_id"}, @@ -662,7 +662,7 @@ def test_historical_feature_retrieval_with_mapping(spark: SparkSession): } booking_source = { "file": { - "format": "csv", + "format": {"json_class": "CSVFormat"}, "path": f"file://{path.join(test_data_dir, 'column_mapping_test_feature.csv')}", "event_timestamp_column": "datetime", "created_timestamp_column": "created_datetime", @@ -723,7 +723,7 @@ def test_large_historical_feature_retrieval( entity_source = { "file": { - "format": "csv", + "format": {"json_class": "CSVFormat"}, "path": f"file://{large_entity_csv_file}", "event_timestamp_column": "event_timestamp", "field_mapping": {"id": "customer_id"}, @@ -732,7 +732,7 @@ def test_large_historical_feature_retrieval( } feature_source = { "file": { - "format": "csv", + "format": {"json_class": "CSVFormat"}, "path": f"file://{large_feature_csv_file}", "event_timestamp_column": "event_timestamp", "created_timestamp_column": "created_timestamp", @@ -755,7 +755,7 @@ def test_historical_feature_retrieval_with_schema_errors(spark: SparkSession): test_data_dir = path.join(pathlib.Path(__file__).parent.absolute(), "data") entity_source = { "file": { - "format": "csv", + "format": {"json_class": "CSVFormat"}, "path": f"file://{path.join(test_data_dir, 'customer_driver_pairs.csv')}", "event_timestamp_column": "event_timestamp", "options": {"inferSchema": "true", "header": "true"}, @@ -763,7 +763,7 @@ def test_historical_feature_retrieval_with_schema_errors(spark: SparkSession): } entity_source_missing_timestamp = { "file": { - "format": "csv", + "format": {"json_class": "CSVFormat"}, "path": f"file://{path.join(test_data_dir, 'customer_driver_pairs.csv')}", "event_timestamp_column": "datetime", "options": {"inferSchema": "true", "header": "true"}, @@ -771,7 +771,7 @@ def test_historical_feature_retrieval_with_schema_errors(spark: SparkSession): } entity_source_missing_entity = { "file": { - "format": "csv", + "format": {"json_class": "CSVFormat"}, "path": f"file://{path.join(test_data_dir, 'customers.csv')}", "event_timestamp_column": "event_timestamp", "options": {"inferSchema": "true", "header": "true"}, @@ -780,7 +780,7 @@ def test_historical_feature_retrieval_with_schema_errors(spark: SparkSession): booking_source = { "file": { - "format": "csv", + "format": {"json_class": "CSVFormat"}, "path": f"file://{path.join(test_data_dir, 'bookings.csv')}", "event_timestamp_column": "event_timestamp", "created_timestamp_column": "created_timestamp", @@ -789,7 +789,7 @@ def test_historical_feature_retrieval_with_schema_errors(spark: SparkSession): } booking_source_missing_timestamp = { "file": { - "format": "csv", + "format": {"json_class": "CSVFormat"}, "path": f"file://{path.join(test_data_dir, 'bookings.csv')}", "event_timestamp_column": "datetime", "created_timestamp_column": "created_datetime", diff --git a/spark/ingestion/pom.xml b/spark/ingestion/pom.xml index 74aab0f3db..6f51489086 100644 --- a/spark/ingestion/pom.xml +++ b/spark/ingestion/pom.xml @@ -109,6 +109,12 @@ ${spark.version} + + org.apache.spark + spark-avro_${scala.version} + ${spark.version} + + org.apache.kafka kafka-clients diff --git a/spark/ingestion/src/main/scala/feast/ingestion/BasePipeline.scala b/spark/ingestion/src/main/scala/feast/ingestion/BasePipeline.scala index 91b274d1ef..d958c2ea8d 100644 --- a/spark/ingestion/src/main/scala/feast/ingestion/BasePipeline.scala +++ b/spark/ingestion/src/main/scala/feast/ingestion/BasePipeline.scala @@ -28,9 +28,6 @@ trait BasePipeline { System.setProperty("io.netty.tryReflectionSetAccessible", "true") val conf = new SparkConf() - conf - .setAppName(s"${jobConfig.mode} IngestionJob for ${jobConfig.featureTable.name}") - .setMaster("local") jobConfig.store match { case RedisConfig(host, port, ssl) => diff --git a/spark/ingestion/src/main/scala/feast/ingestion/IngestionJob.scala b/spark/ingestion/src/main/scala/feast/ingestion/IngestionJob.scala index 82168dbe81..88cc2b8b8e 100644 --- a/spark/ingestion/src/main/scala/feast/ingestion/IngestionJob.scala +++ b/spark/ingestion/src/main/scala/feast/ingestion/IngestionJob.scala @@ -26,7 +26,8 @@ object IngestionJob { import Modes._ implicit val modesRead: scopt.Read[Modes.Value] = scopt.Read.reads(Modes withName _.capitalize) implicit val formats: Formats = DefaultFormats + - new JavaEnumNameSerializer[feast.proto.types.ValueProto.ValueType.Enum]() + new JavaEnumNameSerializer[feast.proto.types.ValueProto.ValueType.Enum]() + + ShortTypeHints(List(classOf[ProtoFormat], classOf[AvroFormat])) val parser = new scopt.OptionParser[IngestionJobConfig]("IngestionJon") { // ToDo: read version from Manifest @@ -75,6 +76,7 @@ object IngestionJob { def main(args: Array[String]): Unit = { parser.parse(args, IngestionJobConfig()) match { case Some(config) => + println(s"Starting with config $config") config.mode match { case Modes.Offline => val sparkSession = BatchPipeline.createSparkSession(config) diff --git a/spark/ingestion/src/main/scala/feast/ingestion/IngestionJobConfig.scala b/spark/ingestion/src/main/scala/feast/ingestion/IngestionJobConfig.scala index f13298ef2d..9ca3612180 100644 --- a/spark/ingestion/src/main/scala/feast/ingestion/IngestionJobConfig.scala +++ b/spark/ingestion/src/main/scala/feast/ingestion/IngestionJobConfig.scala @@ -32,6 +32,11 @@ abstract class MetricConfig case class StatsDConfig(host: String, port: Int) extends MetricConfig +abstract class DataFormat +case class ParquetFormat() extends DataFormat +case class ProtoFormat(classPath: String) extends DataFormat +case class AvroFormat(schemaJson: String) extends DataFormat + abstract class Source { def fieldMapping: Map[String, String] @@ -43,7 +48,7 @@ abstract class Source { abstract class BatchSource extends Source abstract class StreamingSource extends Source { - def classpath: String + def format: DataFormat } case class FileSource( @@ -67,7 +72,7 @@ case class BQSource( case class KafkaSource( bootstrapServers: String, topic: String, - override val classpath: String, + override val format: DataFormat, override val fieldMapping: Map[String, String], override val eventTimestampColumn: String, override val createdTimestampColumn: Option[String] = None, diff --git a/spark/ingestion/src/main/scala/feast/ingestion/StreamingPipeline.scala b/spark/ingestion/src/main/scala/feast/ingestion/StreamingPipeline.scala index 576849f4b2..07fd20f2c1 100644 --- a/spark/ingestion/src/main/scala/feast/ingestion/StreamingPipeline.scala +++ b/spark/ingestion/src/main/scala/feast/ingestion/StreamingPipeline.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.functions.udf import feast.ingestion.utils.ProtoReflection import feast.ingestion.validation.{RowValidator, TypeCheck} import org.apache.spark.sql.streaming.StreamingQuery +import org.apache.spark.sql.avro._ /** * Streaming pipeline (currently in micro-batches mode only, since we need to have multiple sinks: redis & deadletters). @@ -44,9 +45,6 @@ object StreamingPipeline extends BasePipeline with Serializable { inputProjection(config.source, featureTable.features, featureTable.entities) val validator = new RowValidator(featureTable, config.source.eventTimestampColumn) - val messageParser = - protoParser(sparkSession, config.source.asInstanceOf[StreamingSource].classpath) - val input = config.source match { case source: KafkaSource => sparkSession.readStream @@ -56,8 +54,15 @@ object StreamingPipeline extends BasePipeline with Serializable { .load() } - val projected = input - .withColumn("features", messageParser($"value")) + val parsed = config.source.asInstanceOf[StreamingSource].format match { + case ProtoFormat(classPath) => + val parser = protoParser(sparkSession, classPath) + input.withColumn("features", parser($"value")) + case AvroFormat(schemaJson) => + input.select(from_avro($"value", schemaJson).alias("features")) + } + + val projected = parsed .select("features.*") .select(projection: _*) diff --git a/spark/ingestion/src/test/scala/feast/ingestion/BatchPipelineIT.scala b/spark/ingestion/src/test/scala/feast/ingestion/BatchPipelineIT.scala index 150c3c1dce..3766f26cd4 100644 --- a/spark/ingestion/src/test/scala/feast/ingestion/BatchPipelineIT.scala +++ b/spark/ingestion/src/test/scala/feast/ingestion/BatchPipelineIT.scala @@ -29,7 +29,12 @@ import feast.ingestion.helpers.DataHelper._ import feast.proto.storage.RedisProto.RedisKeyV2 import feast.proto.types.ValueProto -case class Row(customer: String, feature1: Int, feature2: Float, eventTimestamp: java.sql.Timestamp) +case class TestRow( + customer: String, + feature1: Int, + feature2: Float, + eventTimestamp: java.sql.Timestamp +) class BatchPipelineIT extends SparkSpec with ForAllTestContainer { @@ -51,9 +56,14 @@ class BatchPipelineIT extends SparkSpec with ForAllTestContainer { eventTimestamp <- Gen .choose(0, Seconds.secondsBetween(start, end).getSeconds) .map(start.withMillisOfSecond(0).plusSeconds) - } yield Row(customer, feature1, feature2, new java.sql.Timestamp(eventTimestamp.getMillis)) + } yield TestRow( + customer, + feature1, + feature2, + new java.sql.Timestamp(eventTimestamp.getMillis) + ) - def encodeEntityKey(row: Row, featureTable: FeatureTable): Array[Byte] = { + def encodeEntityKey(row: TestRow, featureTable: FeatureTable): Array[Byte] = { RedisKeyV2 .newBuilder() .setProject(featureTable.project) @@ -63,7 +73,7 @@ class BatchPipelineIT extends SparkSpec with ForAllTestContainer { .toByteArray } - def groupByEntity(row: Row) = + def groupByEntity(row: TestRow) = new String(encodeEntityKey(row, config.featureTable)) val config = IngestionJobConfig( diff --git a/spark/ingestion/src/test/scala/feast/ingestion/StreamingPipelineIT.scala b/spark/ingestion/src/test/scala/feast/ingestion/StreamingPipelineIT.scala index a6f199d327..1bea232548 100644 --- a/spark/ingestion/src/test/scala/feast/ingestion/StreamingPipelineIT.scala +++ b/spark/ingestion/src/test/scala/feast/ingestion/StreamingPipelineIT.scala @@ -38,12 +38,17 @@ import feast.ingestion.helpers.RedisStorageHelper._ import feast.ingestion.helpers.DataHelper._ import feast.proto.storage.RedisProto.RedisKeyV2 import feast.proto.types.ValueProto +import org.apache.spark.sql.Row +import org.apache.spark.sql.avro.to_avro +import org.apache.spark.sql.functions.{col, struct} +import org.apache.spark.sql.types.StructType class StreamingPipelineIT extends SparkSpec with ForAllTestContainer { val redisContainer = GenericContainer("redis:6.0.8", exposedPorts = Seq(6379)) val kafkaContainer = KafkaContainer() override val container = MultipleContainers(redisContainer, kafkaContainer) + override def withSparkConfOverrides(conf: SparkConf): SparkConf = conf .set("spark.redis.host", redisContainer.host) .set("spark.redis.port", redisContainer.mappedPort(6379).toString) @@ -112,7 +117,7 @@ class StreamingPipelineIT extends SparkSpec with ForAllTestContainer { val kafkaSource = KafkaSource( bootstrapServers = kafkaContainer.bootstrapServers, topic = "topic", - classpath = "com.example.protos.TestMessage", + format = ProtoFormat("com.example.protos.TestMessage"), fieldMapping = Map.empty, eventTimestampColumn = "event_timestamp" ) @@ -167,7 +172,7 @@ class StreamingPipelineIT extends SparkSpec with ForAllTestContainer { "All protobuf types" should "be correctly converted" in new Scope { val configWithKafka = config.copy( source = kafkaSource.copy( - classpath = "com.example.protos.AllTypesMessage", + format = ProtoFormat("com.example.protos.AllTypesMessage"), fieldMapping = Map( "map_value" -> "map.key", "inner_double" -> "inner.double", @@ -291,4 +296,75 @@ class StreamingPipelineIT extends SparkSpec with ForAllTestContainer { StreamingPipeline.createPipeline(sparkSession, configWithKafka).get } } + + "Streaming pipeline" should "store valid avro messages from kafka to redis" in new Scope { + val avroConfig = IngestionJobConfig( + featureTable = FeatureTable( + name = "test-fs", + project = "default", + entities = Seq(Field("customer", ValueType.Enum.STRING)), + features = Seq( + Field("feature1", ValueType.Enum.INT32), + Field("feature2", ValueType.Enum.FLOAT) + ) + ), + source = KafkaSource( + bootstrapServers = kafkaContainer.bootstrapServers, + topic = "avro", + format = AvroFormat(schemaJson = """{ + |"type": "record", + |"name": "TestMessage", + |"fields": [ + |{"name": "customer", "type": ["string","null"]}, + |{"name": "feature1", "type": "int"}, + |{"name": "feature2", "type": "float"}, + |{"name": "eventTimestamp", "type": [{"type": "long", "logicalType": "timestamp-micros"}, "null"]} + |] + |}""".stripMargin), + fieldMapping = Map.empty, + eventTimestampColumn = "eventTimestamp" + ) + ) + val query = StreamingPipeline.createPipeline(sparkSession, avroConfig).get + query.processAllAvailable() // to init kafka consumer + + val row = TestRow("aaa", 1, 0.5f, new java.sql.Timestamp(DateTime.now.withMillis(0).getMillis)) + val df = sparkSession.createDataFrame(Seq(row)) + df + .select( + to_avro( + struct( + col("customer"), + col("feature1"), + col("feature2"), + col("eventTimestamp") + ) + ).alias("value") + ) + .write + .format("kafka") + .option("kafka.bootstrap.servers", kafkaContainer.bootstrapServers) + .option("topic", "avro") + .save() + + query.processAllAvailable() + + val redisKey = RedisKeyV2 + .newBuilder() + .setProject("default") + .addEntityNames("customer") + .addEntityValues(ValueProto.Value.newBuilder().setStringVal("aaa")) + .build() + + val storedValues = jedis.hgetAll(redisKey.toByteArray).asScala.toMap + val customFeatureKeyEncoder: String => String = encodeFeatureKey(avroConfig.featureTable) + storedValues should beStoredRow( + Map( + customFeatureKeyEncoder("feature1") -> row.feature1, + customFeatureKeyEncoder("feature2") -> row.feature2, + "_ts:test-fs" -> row.eventTimestamp + ) + ) + + } } diff --git a/tests/e2e/requirements.txt b/tests/e2e/requirements.txt index 68595ee1b5..d8b1950926 100644 --- a/tests/e2e/requirements.txt +++ b/tests/e2e/requirements.txt @@ -8,7 +8,9 @@ pytest-mock==1.10.4 pytest-timeout==1.3.3 pytest-ordering==0.6.* pytest-xdist==2.1.0 -tensorflow-data-validation==0.21.2 +# tensorflow-data-validation==0.21.2 deepdiff==4.3.2 -tensorflow==2.1.0 -tfx-bsl==0.21.* # lock to 0.21 +# tensorflow==2.1.0 +# tfx-bsl==0.21.* # lock to 0.21 +confluent_kafka +avro==1.10.0 \ No newline at end of file diff --git a/tests/e2e/test_online_features.py b/tests/e2e/test_online_features.py index 552ad346cd..fc255c172b 100644 --- a/tests/e2e/test_online_features.py +++ b/tests/e2e/test_online_features.py @@ -1,16 +1,30 @@ +import io +import json import os import time import uuid from datetime import datetime, timedelta from pathlib import Path +import avro.schema import numpy as np import pandas as pd import pyspark import pytest - -from feast import Client, Entity, Feature, FeatureTable, FileSource, ValueType -from feast.data_format import ParquetFormat +import pytz +from avro.io import BinaryEncoder, DatumWriter +from confluent_kafka import Producer + +from feast import ( + Client, + Entity, + Feature, + FeatureTable, + FileSource, + KafkaSource, + ValueType, +) +from feast.data_format import AvroFormat, ParquetFormat from feast.pyspark.abc import SparkJobStatus from feast.wait import wait_retry_backoff @@ -110,11 +124,7 @@ def test_offline_ingestion(feast_client: Client, staging_path: str): feature_table, datetime.today(), datetime.today() + timedelta(days=1) ) - status = wait_retry_backoff( - lambda: (job.get_status(), job.get_status() != SparkJobStatus.IN_PROGRESS), 300 - ) - - assert status == SparkJobStatus.COMPLETED + wait_retry_backoff(lambda: (None, job.get_status() == SparkJobStatus.COMPLETED), 60) features = feast_client.get_online_features( ["drivers:unique_drivers"], @@ -128,3 +138,113 @@ def test_offline_ingestion(feast_client: Client, staging_path: str): columns={"unique_drivers": "drivers:unique_drivers"} ), ) + + +def test_streaming_ingestion(feast_client: Client, staging_path: str, pytestconfig): + entity = Entity(name="s2id", description="S2id", value_type=ValueType.INT64,) + + feature_table = FeatureTable( + name="drivers_stream", + entities=["s2id"], + features=[Feature("unique_drivers", ValueType.INT64)], + batch_source=FileSource( + "event_timestamp", + "event_timestamp", + ParquetFormat(), + os.path.join(staging_path, "batch-storage"), + ), + stream_source=KafkaSource( + "event_timestamp", + "event_timestamp", + pytestconfig.getoption("kafka_brokers"), + AvroFormat(avro_schema()), + topic="avro", + ), + ) + + feast_client.apply_entity(entity) + feast_client.apply_feature_table(feature_table) + + job = feast_client.start_stream_to_online_ingestion(feature_table) + + wait_retry_backoff( + lambda: (None, job.get_status() == SparkJobStatus.IN_PROGRESS), 60 + ) + + try: + original = generate_data()[["s2id", "unique_drivers", "event_timestamp"]] + for record in original.to_dict("records"): + record["event_timestamp"] = ( + record["event_timestamp"].to_pydatetime().replace(tzinfo=pytz.utc) + ) + + send_avro_record_to_kafka( + "avro", + record, + bootstrap_servers=pytestconfig.getoption("kafka_brokers"), + avro_schema_json=avro_schema(), + ) + + def get_online_features(): + features = feast_client.get_online_features( + ["drivers_stream:unique_drivers"], + entity_rows=[{"s2id": s2_id} for s2_id in original["s2id"].tolist()], + ).to_dict() + df = pd.DataFrame.from_dict(features) + return df, not df["drivers_stream:unique_drivers"].isna().any() + + ingested = wait_retry_backoff(get_online_features, 60) + finally: + job.cancel() + + pd.testing.assert_frame_equal( + ingested[["s2id", "drivers_stream:unique_drivers"]], + original[["s2id", "unique_drivers"]].rename( + columns={"unique_drivers": "drivers_stream:unique_drivers"} + ), + ) + + +def avro_schema(): + return json.dumps( + { + "type": "record", + "name": "TestMessage", + "fields": [ + {"name": "s2id", "type": "long"}, + {"name": "unique_drivers", "type": "long"}, + { + "name": "event_timestamp", + "type": {"type": "long", "logicalType": "timestamp-micros"}, + }, + ], + } + ) + + +def send_avro_record_to_kafka(topic, value, bootstrap_servers, avro_schema_json): + value_schema = avro.schema.parse(avro_schema_json) + + producer_config = { + "bootstrap.servers": bootstrap_servers, + "request.timeout.ms": "1000", + } + + producer = Producer(producer_config) + + writer = DatumWriter(value_schema) + bytes_writer = io.BytesIO() + encoder = BinaryEncoder(bytes_writer) + + writer.write(value, encoder) + + try: + producer.produce(topic=topic, value=bytes_writer.getvalue()) + except Exception as e: + print( + f"Exception while producing record value - {value} to topic - {topic}: {e}" + ) + else: + print(f"Successfully producing record value - {value} to topic - {topic}") + + producer.flush()