diff --git a/.github/workflows/master_only.yml b/.github/workflows/master_only.yml
index 2852c559dd..20caebe0af 100644
--- a/.github/workflows/master_only.yml
+++ b/.github/workflows/master_only.yml
@@ -79,7 +79,15 @@ jobs:
- uses: stCarolas/setup-maven@v3
with:
maven-version: 3.6.3
- - name: build-jar
+ - name: Publish develop version of ingestion job
+ run: |
+          if [ "${GITHUB_REF#refs/*/}" = "master" ]; then
+ make build-java-no-tests REVISION=develop
+ gsutil cp ./spark/ingestion/target/feast-ingestion-spark-develop.jar gs://${PUBLISH_BUCKET}/spark/ingestion/
+ fi
+ - name: Get version
+ run: echo ::set-env name=RELEASE_VERSION::${GITHUB_REF#refs/*/}
+ - name: Publish tagged version of ingestion job
run: |
SEMVER_REGEX='^v[0-9]+\.[0-9]+\.[0-9]+(-([0-9A-Za-z-]+(\.[0-9A-Za-z-]+)*))?$'
if echo "${RELEASE_VERSION}" | grep -P "$SEMVER_REGEX" &>/dev/null ; then
diff --git a/core/src/main/java/feast/core/validators/DataSourceValidator.java b/core/src/main/java/feast/core/validators/DataSourceValidator.java
index 223906d0e6..f36e3609d7 100644
--- a/core/src/main/java/feast/core/validators/DataSourceValidator.java
+++ b/core/src/main/java/feast/core/validators/DataSourceValidator.java
@@ -51,6 +51,8 @@ public static void validate(DataSource spec) {
spec.getKafkaOptions().getMessageFormat().getProtoFormat().getClassPath(),
"FeatureTable");
break;
+ case AVRO_FORMAT:
+ break;
default:
throw new UnsupportedOperationException(
String.format(
@@ -68,6 +70,8 @@ public static void validate(DataSource spec) {
spec.getKinesisOptions().getRecordFormat().getProtoFormat().getClassPath(),
"FeatureTable");
break;
+ case AVRO_FORMAT:
+ break;
default:
throw new UnsupportedOperationException(
String.format("Unsupported Stream Format for Kafka Source Type: %s", recordFormat));
diff --git a/infra/charts/feast/values.yaml b/infra/charts/feast/values.yaml
index b3a2df67be..a42349c376 100644
--- a/infra/charts/feast/values.yaml
+++ b/infra/charts/feast/values.yaml
@@ -4,7 +4,7 @@ feast-core:
feast-jobcontroller:
# feast-jobcontroller.enabled -- Flag to install Feast Job Controller
- enabled: true
+ enabled: false
feast-online-serving:
# feast-online-serving.enabled -- Flag to install Feast Online Serving
diff --git a/sdk/python/feast/constants.py b/sdk/python/feast/constants.py
index e55c748211..071e867d4c 100644
--- a/sdk/python/feast/constants.py
+++ b/sdk/python/feast/constants.py
@@ -123,7 +123,7 @@ class AuthProvider(Enum):
# Authentication Provider - Google OpenID/OAuth
CONFIG_AUTH_PROVIDER: "google",
CONFIG_SPARK_LAUNCHER: "dataproc",
- CONFIG_SPARK_INGESTION_JOB_JAR: "gs://feast-jobs/feast-ingestion-spark-0.8-SNAPSHOT.jar",
+ CONFIG_SPARK_INGESTION_JOB_JAR: "gs://feast-jobs/feast-ingestion-spark-develop.jar",
CONFIG_REDIS_HOST: "localhost",
CONFIG_REDIS_PORT: "6379",
CONFIG_REDIS_SSL: "False",
diff --git a/sdk/python/feast/pyspark/abc.py b/sdk/python/feast/pyspark/abc.py
index ed65672950..45005e20d1 100644
--- a/sdk/python/feast/pyspark/abc.py
+++ b/sdk/python/feast/pyspark/abc.py
@@ -19,6 +19,7 @@ class SparkJobFailure(Exception):
class SparkJobStatus(Enum):
+ STARTING = 0
IN_PROGRESS = 1
FAILED = 2
COMPLETED = 3
@@ -48,6 +49,13 @@ def get_status(self) -> SparkJobStatus:
"""
raise NotImplementedError
+ @abc.abstractmethod
+ def cancel(self):
+ """
+ Manually terminate job
+ """
+ raise NotImplementedError
+
class SparkJobParameters(abc.ABC):
@abc.abstractmethod
diff --git a/sdk/python/feast/pyspark/historical_feature_retrieval_job.py b/sdk/python/feast/pyspark/historical_feature_retrieval_job.py
index 1d221e0f32..3de7ec660f 100644
--- a/sdk/python/feast/pyspark/historical_feature_retrieval_job.py
+++ b/sdk/python/feast/pyspark/historical_feature_retrieval_job.py
@@ -75,6 +75,12 @@ class FileSource(Source):
options (Optional[Dict[str, str]]): Options to be passed to spark while reading the file source.
"""
+ PROTO_FORMAT_TO_SPARK = {
+ "ParquetFormat": "parquet",
+ "AvroFormat": "avro",
+ "CSVFormat": "csv",
+ }
+
def __init__(
self,
format: str,
@@ -147,7 +153,7 @@ def spark_path(self) -> str:
def _source_from_dict(dct: Dict) -> Source:
if "file" in dct.keys():
return FileSource(
- dct["file"]["format"],
+ FileSource.PROTO_FORMAT_TO_SPARK[dct["file"]["format"]["json_class"]],
dct["file"]["path"],
dct["file"]["event_timestamp_column"],
dct["file"].get("created_timestamp_column"),
@@ -635,7 +641,7 @@ def retrieve_historical_features(
Example:
>>> entity_source_conf = {
- "format": "csv",
+            "format": {"json_class": "CSVFormat"},
"path": "file:///some_dir/customer_driver_pairs.csv"),
"options": {"inferSchema": "true", "header": "true"},
"field_mapping": {"id": "driver_id"}
@@ -643,12 +649,12 @@ def retrieve_historical_features(
>>> feature_tables_sources_conf = [
{
- "format": "parquet",
+ "format": {"json_class": "ParquetFormat"},
"path": "gs://some_bucket/bookings.parquet"),
"field_mapping": {"id": "driver_id"}
},
{
- "format": "avro",
+            "format": {"json_class": "AvroFormat", "schema_json": "..avro schema.."},
"path": "s3://some_bucket/transactions.avro"),
}
]
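For context, a minimal sketch (not part of this patch; it only uses the FileSource class and the PROTO_FORMAT_TO_SPARK mapping added above) of how the new "json_class"-keyed format dict resolves to a Spark reader format inside the retrieval job:

    from feast.pyspark.historical_feature_retrieval_job import FileSource

    # Source conf in the shape the launcher now emits for a CSV file source (illustrative values).
    entity_source_conf = {
        "file": {
            "format": {"json_class": "CSVFormat"},
            "path": "file:///some_dir/customer_driver_pairs.csv",
            "event_timestamp_column": "event_timestamp",
            "options": {"inferSchema": "true", "header": "true"},
        }
    }

    # _source_from_dict looks the class name up in PROTO_FORMAT_TO_SPARK to get the Spark short name.
    spark_format = FileSource.PROTO_FORMAT_TO_SPARK[
        entity_source_conf["file"]["format"]["json_class"]
    ]
    assert spark_format == "csv"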
diff --git a/sdk/python/feast/pyspark/launcher.py b/sdk/python/feast/pyspark/launcher.py
index 5156a7c541..defd181442 100644
--- a/sdk/python/feast/pyspark/launcher.py
+++ b/sdk/python/feast/pyspark/launcher.py
@@ -86,9 +86,6 @@ def resolve_launcher(config: Config) -> JobLauncher:
return _launchers[config.get(CONFIG_SPARK_LAUNCHER)](config)
-_SOURCES = {FileSource: "file", BigQuerySource: "bq", KafkaSource: "kafka"}
-
-
def _source_to_argument(source: DataSource):
common_properties = {
"field_mapping": dict(source.field_mapping),
@@ -97,20 +94,28 @@ def _source_to_argument(source: DataSource):
"date_partition_column": source.date_partition_column,
}
- kind = _SOURCES[type(source)]
properties = {**common_properties}
+
if isinstance(source, FileSource):
properties["path"] = source.file_options.file_url
- properties["format"] = str(source.file_options.file_format)
- return {kind: properties}
+ properties["format"] = dict(
+ json_class=source.file_options.file_format.__class__.__name__
+ )
+ return {"file": properties}
+
if isinstance(source, BigQuerySource):
properties["table_ref"] = source.bigquery_options.table_ref
- return {kind: properties}
+ return {"bq": properties}
+
if isinstance(source, KafkaSource):
- properties["topic"] = source.kafka_options.topic
- properties["classpath"] = source.kafka_options.class_path
properties["bootstrap_servers"] = source.kafka_options.bootstrap_servers
- return {kind: properties}
+ properties["topic"] = source.kafka_options.topic
+ properties["format"] = {
+ **source.kafka_options.message_format.__dict__,
+ "json_class": source.kafka_options.message_format.__class__.__name__,
+ }
+ return {"kafka": properties}
+
raise NotImplementedError(f"Unsupported Datasource: {type(source)}")
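A rough usage sketch of the new serialization path (argument order for KafkaSource follows the e2e test later in this diff; the exact contents of the "format" dict depend on the format object's attributes, so the printed output is approximate):

    from feast import KafkaSource
    from feast.data_format import AvroFormat
    from feast.pyspark.launcher import _source_to_argument

    source = KafkaSource(
        "event_timestamp",    # event timestamp column
        "created_timestamp",  # created timestamp column
        "localhost:9092",     # bootstrap servers
        AvroFormat('{"type": "record", "name": "TestMessage", "fields": []}'),
        topic="avro",
    )

    # The message format is now passed as a dict carrying a "json_class" discriminator,
    # which the Scala IngestionJob maps back onto the ProtoFormat/AvroFormat case classes
    # registered via ShortTypeHints in IngestionJob.scala below.
    print(_source_to_argument(source))
    # e.g. {"kafka": {..., "format": {..., "json_class": "AvroFormat"}}}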
diff --git a/sdk/python/feast/pyspark/launchers/aws/emr.py b/sdk/python/feast/pyspark/launchers/aws/emr.py
index 3b87f91672..c23d07fb59 100644
--- a/sdk/python/feast/pyspark/launchers/aws/emr.py
+++ b/sdk/python/feast/pyspark/launchers/aws/emr.py
@@ -62,6 +62,9 @@ def get_status(self) -> SparkJobStatus:
# we should never get here
raise Exception("Invalid EMR state")
+ def cancel(self):
+ raise NotImplementedError
+
class EmrRetrievalJob(EmrJobMixin, RetrievalJob):
"""
diff --git a/sdk/python/feast/pyspark/launchers/gcloud/dataproc.py b/sdk/python/feast/pyspark/launchers/gcloud/dataproc.py
index 1d6a0585c2..ff314194c9 100644
--- a/sdk/python/feast/pyspark/launchers/gcloud/dataproc.py
+++ b/sdk/python/feast/pyspark/launchers/gcloud/dataproc.py
@@ -45,6 +45,9 @@ def get_status(self) -> SparkJobStatus:
return SparkJobStatus.FAILED
+ def cancel(self):
+ self._operation.cancel()
+
class DataprocRetrievalJob(DataprocJobMixin, RetrievalJob):
"""
@@ -71,7 +74,13 @@ def get_output_file_uri(self, timeout_sec=None):
class DataprocBatchIngestionJob(DataprocJobMixin, BatchIngestionJob):
"""
- Ingestion job result for a Dataproc cluster
+ Batch Ingestion job result for a Dataproc cluster
+ """
+
+
+class DataprocStreamingIngestionJob(DataprocJobMixin, StreamIngestionJob):
+ """
+ Streaming Ingestion job result for a Dataproc cluster
"""
@@ -151,14 +160,14 @@ def historical_feature_retrieval(
)
def offline_to_online_ingestion(
- self, job_params: BatchIngestionJobParameters
+ self, ingestion_job_params: BatchIngestionJobParameters
) -> BatchIngestionJob:
- return DataprocBatchIngestionJob(self.dataproc_submit(job_params))
+ return DataprocBatchIngestionJob(self.dataproc_submit(ingestion_job_params))
def start_stream_to_online_ingestion(
self, ingestion_job_params: StreamIngestionJobParameters
) -> StreamIngestionJob:
- raise NotImplementedError
+ return DataprocStreamingIngestionJob(self.dataproc_submit(ingestion_job_params))
def stage_dataframe(
self, df, event_timestamp_column: str, created_timestamp_column: str,
diff --git a/sdk/python/feast/pyspark/launchers/standalone/local.py b/sdk/python/feast/pyspark/launchers/standalone/local.py
index 2ed1c26a08..fdbbb45064 100644
--- a/sdk/python/feast/pyspark/launchers/standalone/local.py
+++ b/sdk/python/feast/pyspark/launchers/standalone/local.py
@@ -1,6 +1,11 @@
import os
+import socket
import subprocess
import uuid
+from contextlib import closing
+
+import requests
+from requests.exceptions import RequestException
from feast.pyspark.abc import (
BatchIngestionJob,
@@ -16,17 +21,53 @@
)
+def _find_free_port():
+ with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
+ s.bind(("", 0))
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ return s.getsockname()[1]
+
+
class StandaloneClusterJobMixin:
- def __init__(self, job_id: str, process: subprocess.Popen):
+ def __init__(
+ self, job_id: str, job_name: str, process: subprocess.Popen, ui_port: int = None
+ ):
self._job_id = job_id
+ self._job_name = job_name
self._process = process
+ self._ui_port = ui_port
def get_id(self) -> str:
return self._job_id
+ def check_if_started(self):
+ if not self._ui_port:
+ return True
+
+ try:
+ applications = requests.get(
+ f"http://localhost:{self._ui_port}/api/v1/applications"
+ ).json()
+ except RequestException:
+ return False
+
+ app = next(
+ iter(app for app in applications if app["name"] == self._job_name), None
+ )
+ if not app:
+ return False
+
+ stages = requests.get(
+ f"http://localhost:{self._ui_port}/api/v1/applications/{app['id']}/stages"
+ ).json()
+ return bool(stages)
+
def get_status(self) -> SparkJobStatus:
code = self._process.poll()
if code is None:
+ if not self.check_if_started():
+ return SparkJobStatus.STARTING
+
return SparkJobStatus.IN_PROGRESS
if code != 0:
@@ -34,10 +75,23 @@ def get_status(self) -> SparkJobStatus:
return SparkJobStatus.COMPLETED
+ def cancel(self):
+ self._process.terminate()
+
class StandaloneClusterBatchIngestionJob(StandaloneClusterJobMixin, BatchIngestionJob):
"""
- Ingestion job result for a standalone spark cluster
+ Batch Ingestion job result for a standalone spark cluster
+ """
+
+ pass
+
+
+class StandaloneClusterStreamingIngestionJob(
+ StandaloneClusterJobMixin, StreamIngestionJob
+):
+ """
+ Streaming Ingestion job result for a standalone spark cluster
"""
pass
@@ -48,7 +102,13 @@ class StandaloneClusterRetrievalJob(StandaloneClusterJobMixin, RetrievalJob):
Historical feature retrieval job result for a standalone spark cluster
"""
- def __init__(self, job_id: str, process: subprocess.Popen, output_file_uri: str):
+ def __init__(
+ self,
+ job_id: str,
+ job_name: str,
+ process: subprocess.Popen,
+ output_file_uri: str,
+ ):
"""
This is the returned historical feature retrieval job result for StandaloneClusterLauncher.
@@ -57,7 +117,7 @@ def __init__(self, job_id: str, process: subprocess.Popen, output_file_uri: str)
process (subprocess.Popen): Pyspark driver process, spawned by the launcher.
output_file_uri (str): Uri to the historical feature retrieval job output file.
"""
- super().__init__(job_id, process)
+ super().__init__(job_id, job_name, process)
self._output_file_uri = output_file_uri
def get_output_file_uri(self, timeout_sec: int = None):
@@ -100,7 +160,9 @@ def __init__(self, master_url: str, spark_home: str = None):
def spark_submit_script_path(self):
return os.path.join(self.spark_home, "bin/spark-submit")
- def spark_submit(self, job_params: SparkJobParameters) -> subprocess.Popen:
+ def spark_submit(
+ self, job_params: SparkJobParameters, ui_port: int = None
+ ) -> subprocess.Popen:
submission_cmd = [
self.spark_submit_script_path,
"--master",
@@ -112,6 +174,9 @@ def spark_submit(self, job_params: SparkJobParameters) -> subprocess.Popen:
if job_params.get_class_name():
submission_cmd.extend(["--class", job_params.get_class_name()])
+ if ui_port:
+ submission_cmd.extend(["--conf", f"spark.ui.port={ui_port}"])
+
submission_cmd.append(job_params.get_main_file_path())
submission_cmd.extend(job_params.get_arguments())
@@ -122,19 +187,35 @@ def historical_feature_retrieval(
) -> RetrievalJob:
job_id = str(uuid.uuid4())
return StandaloneClusterRetrievalJob(
- job_id, self.spark_submit(job_params), job_params.get_destination_path()
+ job_id,
+ job_params.get_name(),
+ self.spark_submit(job_params),
+ job_params.get_destination_path(),
)
def offline_to_online_ingestion(
- self, job_params: BatchIngestionJobParameters
+ self, ingestion_job_params: BatchIngestionJobParameters
) -> BatchIngestionJob:
job_id = str(uuid.uuid4())
- return StandaloneClusterBatchIngestionJob(job_id, self.spark_submit(job_params))
+ ui_port = _find_free_port()
+ return StandaloneClusterBatchIngestionJob(
+ job_id,
+ ingestion_job_params.get_name(),
+ self.spark_submit(ingestion_job_params, ui_port),
+ ui_port,
+ )
def start_stream_to_online_ingestion(
self, ingestion_job_params: StreamIngestionJobParameters
) -> StreamIngestionJob:
- raise NotImplementedError
+ job_id = str(uuid.uuid4())
+ ui_port = _find_free_port()
+ return StandaloneClusterStreamingIngestionJob(
+ job_id,
+ ingestion_job_params.get_name(),
+ self.spark_submit(ingestion_job_params, ui_port),
+ ui_port,
+ )
def stage_dataframe(
self, df, event_timestamp_column: str, created_timestamp_column: str,
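For reference, a sketch of how a caller can use the new STARTING state and cancel() (it mirrors the streaming e2e test added below; feast_client and feature_table are assumed to exist already):

    from feast.pyspark.abc import SparkJobStatus
    from feast.wait import wait_retry_backoff

    # Kick off streaming ingestion (feast_client / feature_table are assumed to be set up).
    job = feast_client.start_stream_to_online_ingestion(feature_table)

    # The standalone launcher reports STARTING until the Spark UI REST API on the job's
    # dedicated port lists at least one stage, so wait for IN_PROGRESS before producing data.
    wait_retry_backoff(lambda: (None, job.get_status() == SparkJobStatus.IN_PROGRESS), 60)

    try:
        pass  # produce records and assert on online features here
    finally:
        job.cancel()  # terminates the spark-submit driver process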
diff --git a/sdk/python/requirements-ci.txt b/sdk/python/requirements-ci.txt
index 62b335e557..eaeb81387e 100644
--- a/sdk/python/requirements-ci.txt
+++ b/sdk/python/requirements-ci.txt
@@ -16,3 +16,5 @@ pandavro==1.5.*
moto
mypy
mypy-protobuf
+avro==1.10.0
+confluent_kafka
diff --git a/sdk/python/requirements-dev.txt b/sdk/python/requirements-dev.txt
index ac7813ec79..f34c0c8924 100644
--- a/sdk/python/requirements-dev.txt
+++ b/sdk/python/requirements-dev.txt
@@ -26,7 +26,6 @@ Sphinx
sphinx-rtd-theme
toml==0.10.*
tqdm==4.*
-confluent_kafka
google
pandavro==1.5.*
kafka-python==1.*
diff --git a/sdk/python/tests/test_as_of_join.py b/sdk/python/tests/test_as_of_join.py
index a6288855c7..23ce94e176 100644
--- a/sdk/python/tests/test_as_of_join.py
+++ b/sdk/python/tests/test_as_of_join.py
@@ -582,7 +582,7 @@ def test_historical_feature_retrieval(spark: SparkSession):
test_data_dir = path.join(pathlib.Path(__file__).parent.absolute(), "data")
entity_source = {
"file": {
- "format": "csv",
+ "format": {"json_class": "CSVFormat"},
"path": f"file://{path.join(test_data_dir, 'customer_driver_pairs.csv')}",
"event_timestamp_column": "event_timestamp",
"options": {"inferSchema": "true", "header": "true"},
@@ -590,7 +590,7 @@ def test_historical_feature_retrieval(spark: SparkSession):
}
booking_source = {
"file": {
- "format": "csv",
+ "format": {"json_class": "CSVFormat"},
"path": f"file://{path.join(test_data_dir, 'bookings.csv')}",
"event_timestamp_column": "event_timestamp",
"created_timestamp_column": "created_timestamp",
@@ -599,7 +599,7 @@ def test_historical_feature_retrieval(spark: SparkSession):
}
transaction_source = {
"file": {
- "format": "csv",
+ "format": {"json_class": "CSVFormat"},
"path": f"file://{path.join(test_data_dir, 'transactions.csv')}",
"event_timestamp_column": "event_timestamp",
"created_timestamp_column": "created_timestamp",
@@ -653,7 +653,7 @@ def test_historical_feature_retrieval_with_mapping(spark: SparkSession):
test_data_dir = path.join(pathlib.Path(__file__).parent.absolute(), "data")
entity_source = {
"file": {
- "format": "csv",
+ "format": {"json_class": "CSVFormat"},
"path": f"file://{path.join(test_data_dir, 'column_mapping_test_entity.csv')}",
"event_timestamp_column": "event_timestamp",
"field_mapping": {"id": "customer_id"},
@@ -662,7 +662,7 @@ def test_historical_feature_retrieval_with_mapping(spark: SparkSession):
}
booking_source = {
"file": {
- "format": "csv",
+ "format": {"json_class": "CSVFormat"},
"path": f"file://{path.join(test_data_dir, 'column_mapping_test_feature.csv')}",
"event_timestamp_column": "datetime",
"created_timestamp_column": "created_datetime",
@@ -723,7 +723,7 @@ def test_large_historical_feature_retrieval(
entity_source = {
"file": {
- "format": "csv",
+ "format": {"json_class": "CSVFormat"},
"path": f"file://{large_entity_csv_file}",
"event_timestamp_column": "event_timestamp",
"field_mapping": {"id": "customer_id"},
@@ -732,7 +732,7 @@ def test_large_historical_feature_retrieval(
}
feature_source = {
"file": {
- "format": "csv",
+ "format": {"json_class": "CSVFormat"},
"path": f"file://{large_feature_csv_file}",
"event_timestamp_column": "event_timestamp",
"created_timestamp_column": "created_timestamp",
@@ -755,7 +755,7 @@ def test_historical_feature_retrieval_with_schema_errors(spark: SparkSession):
test_data_dir = path.join(pathlib.Path(__file__).parent.absolute(), "data")
entity_source = {
"file": {
- "format": "csv",
+ "format": {"json_class": "CSVFormat"},
"path": f"file://{path.join(test_data_dir, 'customer_driver_pairs.csv')}",
"event_timestamp_column": "event_timestamp",
"options": {"inferSchema": "true", "header": "true"},
@@ -763,7 +763,7 @@ def test_historical_feature_retrieval_with_schema_errors(spark: SparkSession):
}
entity_source_missing_timestamp = {
"file": {
- "format": "csv",
+ "format": {"json_class": "CSVFormat"},
"path": f"file://{path.join(test_data_dir, 'customer_driver_pairs.csv')}",
"event_timestamp_column": "datetime",
"options": {"inferSchema": "true", "header": "true"},
@@ -771,7 +771,7 @@ def test_historical_feature_retrieval_with_schema_errors(spark: SparkSession):
}
entity_source_missing_entity = {
"file": {
- "format": "csv",
+ "format": {"json_class": "CSVFormat"},
"path": f"file://{path.join(test_data_dir, 'customers.csv')}",
"event_timestamp_column": "event_timestamp",
"options": {"inferSchema": "true", "header": "true"},
@@ -780,7 +780,7 @@ def test_historical_feature_retrieval_with_schema_errors(spark: SparkSession):
booking_source = {
"file": {
- "format": "csv",
+ "format": {"json_class": "CSVFormat"},
"path": f"file://{path.join(test_data_dir, 'bookings.csv')}",
"event_timestamp_column": "event_timestamp",
"created_timestamp_column": "created_timestamp",
@@ -789,7 +789,7 @@ def test_historical_feature_retrieval_with_schema_errors(spark: SparkSession):
}
booking_source_missing_timestamp = {
"file": {
- "format": "csv",
+ "format": {"json_class": "CSVFormat"},
"path": f"file://{path.join(test_data_dir, 'bookings.csv')}",
"event_timestamp_column": "datetime",
"created_timestamp_column": "created_datetime",
diff --git a/spark/ingestion/pom.xml b/spark/ingestion/pom.xml
index 74aab0f3db..6f51489086 100644
--- a/spark/ingestion/pom.xml
+++ b/spark/ingestion/pom.xml
@@ -109,6 +109,12 @@
${spark.version}
+
+ org.apache.spark
+ spark-avro_${scala.version}
+ ${spark.version}
+
+
org.apache.kafka
kafka-clients
diff --git a/spark/ingestion/src/main/scala/feast/ingestion/BasePipeline.scala b/spark/ingestion/src/main/scala/feast/ingestion/BasePipeline.scala
index 91b274d1ef..d958c2ea8d 100644
--- a/spark/ingestion/src/main/scala/feast/ingestion/BasePipeline.scala
+++ b/spark/ingestion/src/main/scala/feast/ingestion/BasePipeline.scala
@@ -28,9 +28,6 @@ trait BasePipeline {
System.setProperty("io.netty.tryReflectionSetAccessible", "true")
val conf = new SparkConf()
- conf
- .setAppName(s"${jobConfig.mode} IngestionJob for ${jobConfig.featureTable.name}")
- .setMaster("local")
jobConfig.store match {
case RedisConfig(host, port, ssl) =>
diff --git a/spark/ingestion/src/main/scala/feast/ingestion/IngestionJob.scala b/spark/ingestion/src/main/scala/feast/ingestion/IngestionJob.scala
index 82168dbe81..88cc2b8b8e 100644
--- a/spark/ingestion/src/main/scala/feast/ingestion/IngestionJob.scala
+++ b/spark/ingestion/src/main/scala/feast/ingestion/IngestionJob.scala
@@ -26,7 +26,8 @@ object IngestionJob {
import Modes._
implicit val modesRead: scopt.Read[Modes.Value] = scopt.Read.reads(Modes withName _.capitalize)
implicit val formats: Formats = DefaultFormats +
- new JavaEnumNameSerializer[feast.proto.types.ValueProto.ValueType.Enum]()
+ new JavaEnumNameSerializer[feast.proto.types.ValueProto.ValueType.Enum]() +
+ ShortTypeHints(List(classOf[ProtoFormat], classOf[AvroFormat]))
val parser = new scopt.OptionParser[IngestionJobConfig]("IngestionJon") {
// ToDo: read version from Manifest
@@ -75,6 +76,7 @@ object IngestionJob {
def main(args: Array[String]): Unit = {
parser.parse(args, IngestionJobConfig()) match {
case Some(config) =>
+ println(s"Starting with config $config")
config.mode match {
case Modes.Offline =>
val sparkSession = BatchPipeline.createSparkSession(config)
diff --git a/spark/ingestion/src/main/scala/feast/ingestion/IngestionJobConfig.scala b/spark/ingestion/src/main/scala/feast/ingestion/IngestionJobConfig.scala
index f13298ef2d..9ca3612180 100644
--- a/spark/ingestion/src/main/scala/feast/ingestion/IngestionJobConfig.scala
+++ b/spark/ingestion/src/main/scala/feast/ingestion/IngestionJobConfig.scala
@@ -32,6 +32,11 @@ abstract class MetricConfig
case class StatsDConfig(host: String, port: Int) extends MetricConfig
+abstract class DataFormat
+case class ParquetFormat() extends DataFormat
+case class ProtoFormat(classPath: String) extends DataFormat
+case class AvroFormat(schemaJson: String) extends DataFormat
+
abstract class Source {
def fieldMapping: Map[String, String]
@@ -43,7 +48,7 @@ abstract class Source {
abstract class BatchSource extends Source
abstract class StreamingSource extends Source {
- def classpath: String
+ def format: DataFormat
}
case class FileSource(
@@ -67,7 +72,7 @@ case class BQSource(
case class KafkaSource(
bootstrapServers: String,
topic: String,
- override val classpath: String,
+ override val format: DataFormat,
override val fieldMapping: Map[String, String],
override val eventTimestampColumn: String,
override val createdTimestampColumn: Option[String] = None,
diff --git a/spark/ingestion/src/main/scala/feast/ingestion/StreamingPipeline.scala b/spark/ingestion/src/main/scala/feast/ingestion/StreamingPipeline.scala
index 576849f4b2..07fd20f2c1 100644
--- a/spark/ingestion/src/main/scala/feast/ingestion/StreamingPipeline.scala
+++ b/spark/ingestion/src/main/scala/feast/ingestion/StreamingPipeline.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.functions.udf
import feast.ingestion.utils.ProtoReflection
import feast.ingestion.validation.{RowValidator, TypeCheck}
import org.apache.spark.sql.streaming.StreamingQuery
+import org.apache.spark.sql.avro._
/**
* Streaming pipeline (currently in micro-batches mode only, since we need to have multiple sinks: redis & deadletters).
@@ -44,9 +45,6 @@ object StreamingPipeline extends BasePipeline with Serializable {
inputProjection(config.source, featureTable.features, featureTable.entities)
val validator = new RowValidator(featureTable, config.source.eventTimestampColumn)
- val messageParser =
- protoParser(sparkSession, config.source.asInstanceOf[StreamingSource].classpath)
-
val input = config.source match {
case source: KafkaSource =>
sparkSession.readStream
@@ -56,8 +54,15 @@ object StreamingPipeline extends BasePipeline with Serializable {
.load()
}
- val projected = input
- .withColumn("features", messageParser($"value"))
+ val parsed = config.source.asInstanceOf[StreamingSource].format match {
+ case ProtoFormat(classPath) =>
+ val parser = protoParser(sparkSession, classPath)
+ input.withColumn("features", parser($"value"))
+ case AvroFormat(schemaJson) =>
+ input.select(from_avro($"value", schemaJson).alias("features"))
+ }
+
+ val projected = parsed
.select("features.*")
.select(projection: _*)
diff --git a/spark/ingestion/src/test/scala/feast/ingestion/BatchPipelineIT.scala b/spark/ingestion/src/test/scala/feast/ingestion/BatchPipelineIT.scala
index 150c3c1dce..3766f26cd4 100644
--- a/spark/ingestion/src/test/scala/feast/ingestion/BatchPipelineIT.scala
+++ b/spark/ingestion/src/test/scala/feast/ingestion/BatchPipelineIT.scala
@@ -29,7 +29,12 @@ import feast.ingestion.helpers.DataHelper._
import feast.proto.storage.RedisProto.RedisKeyV2
import feast.proto.types.ValueProto
-case class Row(customer: String, feature1: Int, feature2: Float, eventTimestamp: java.sql.Timestamp)
+case class TestRow(
+ customer: String,
+ feature1: Int,
+ feature2: Float,
+ eventTimestamp: java.sql.Timestamp
+)
class BatchPipelineIT extends SparkSpec with ForAllTestContainer {
@@ -51,9 +56,14 @@ class BatchPipelineIT extends SparkSpec with ForAllTestContainer {
eventTimestamp <- Gen
.choose(0, Seconds.secondsBetween(start, end).getSeconds)
.map(start.withMillisOfSecond(0).plusSeconds)
- } yield Row(customer, feature1, feature2, new java.sql.Timestamp(eventTimestamp.getMillis))
+ } yield TestRow(
+ customer,
+ feature1,
+ feature2,
+ new java.sql.Timestamp(eventTimestamp.getMillis)
+ )
- def encodeEntityKey(row: Row, featureTable: FeatureTable): Array[Byte] = {
+ def encodeEntityKey(row: TestRow, featureTable: FeatureTable): Array[Byte] = {
RedisKeyV2
.newBuilder()
.setProject(featureTable.project)
@@ -63,7 +73,7 @@ class BatchPipelineIT extends SparkSpec with ForAllTestContainer {
.toByteArray
}
- def groupByEntity(row: Row) =
+ def groupByEntity(row: TestRow) =
new String(encodeEntityKey(row, config.featureTable))
val config = IngestionJobConfig(
diff --git a/spark/ingestion/src/test/scala/feast/ingestion/StreamingPipelineIT.scala b/spark/ingestion/src/test/scala/feast/ingestion/StreamingPipelineIT.scala
index a6f199d327..1bea232548 100644
--- a/spark/ingestion/src/test/scala/feast/ingestion/StreamingPipelineIT.scala
+++ b/spark/ingestion/src/test/scala/feast/ingestion/StreamingPipelineIT.scala
@@ -38,12 +38,17 @@ import feast.ingestion.helpers.RedisStorageHelper._
import feast.ingestion.helpers.DataHelper._
import feast.proto.storage.RedisProto.RedisKeyV2
import feast.proto.types.ValueProto
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.avro.to_avro
+import org.apache.spark.sql.functions.{col, struct}
+import org.apache.spark.sql.types.StructType
class StreamingPipelineIT extends SparkSpec with ForAllTestContainer {
val redisContainer = GenericContainer("redis:6.0.8", exposedPorts = Seq(6379))
val kafkaContainer = KafkaContainer()
override val container = MultipleContainers(redisContainer, kafkaContainer)
+
override def withSparkConfOverrides(conf: SparkConf): SparkConf = conf
.set("spark.redis.host", redisContainer.host)
.set("spark.redis.port", redisContainer.mappedPort(6379).toString)
@@ -112,7 +117,7 @@ class StreamingPipelineIT extends SparkSpec with ForAllTestContainer {
val kafkaSource = KafkaSource(
bootstrapServers = kafkaContainer.bootstrapServers,
topic = "topic",
- classpath = "com.example.protos.TestMessage",
+ format = ProtoFormat("com.example.protos.TestMessage"),
fieldMapping = Map.empty,
eventTimestampColumn = "event_timestamp"
)
@@ -167,7 +172,7 @@ class StreamingPipelineIT extends SparkSpec with ForAllTestContainer {
"All protobuf types" should "be correctly converted" in new Scope {
val configWithKafka = config.copy(
source = kafkaSource.copy(
- classpath = "com.example.protos.AllTypesMessage",
+ format = ProtoFormat("com.example.protos.AllTypesMessage"),
fieldMapping = Map(
"map_value" -> "map.key",
"inner_double" -> "inner.double",
@@ -291,4 +296,75 @@ class StreamingPipelineIT extends SparkSpec with ForAllTestContainer {
StreamingPipeline.createPipeline(sparkSession, configWithKafka).get
}
}
+
+ "Streaming pipeline" should "store valid avro messages from kafka to redis" in new Scope {
+ val avroConfig = IngestionJobConfig(
+ featureTable = FeatureTable(
+ name = "test-fs",
+ project = "default",
+ entities = Seq(Field("customer", ValueType.Enum.STRING)),
+ features = Seq(
+ Field("feature1", ValueType.Enum.INT32),
+ Field("feature2", ValueType.Enum.FLOAT)
+ )
+ ),
+ source = KafkaSource(
+ bootstrapServers = kafkaContainer.bootstrapServers,
+ topic = "avro",
+ format = AvroFormat(schemaJson = """{
+ |"type": "record",
+ |"name": "TestMessage",
+ |"fields": [
+ |{"name": "customer", "type": ["string","null"]},
+ |{"name": "feature1", "type": "int"},
+ |{"name": "feature2", "type": "float"},
+ |{"name": "eventTimestamp", "type": [{"type": "long", "logicalType": "timestamp-micros"}, "null"]}
+ |]
+ |}""".stripMargin),
+ fieldMapping = Map.empty,
+ eventTimestampColumn = "eventTimestamp"
+ )
+ )
+ val query = StreamingPipeline.createPipeline(sparkSession, avroConfig).get
+ query.processAllAvailable() // to init kafka consumer
+
+ val row = TestRow("aaa", 1, 0.5f, new java.sql.Timestamp(DateTime.now.withMillis(0).getMillis))
+ val df = sparkSession.createDataFrame(Seq(row))
+ df
+ .select(
+ to_avro(
+ struct(
+ col("customer"),
+ col("feature1"),
+ col("feature2"),
+ col("eventTimestamp")
+ )
+ ).alias("value")
+ )
+ .write
+ .format("kafka")
+ .option("kafka.bootstrap.servers", kafkaContainer.bootstrapServers)
+ .option("topic", "avro")
+ .save()
+
+ query.processAllAvailable()
+
+ val redisKey = RedisKeyV2
+ .newBuilder()
+ .setProject("default")
+ .addEntityNames("customer")
+ .addEntityValues(ValueProto.Value.newBuilder().setStringVal("aaa"))
+ .build()
+
+ val storedValues = jedis.hgetAll(redisKey.toByteArray).asScala.toMap
+ val customFeatureKeyEncoder: String => String = encodeFeatureKey(avroConfig.featureTable)
+ storedValues should beStoredRow(
+ Map(
+ customFeatureKeyEncoder("feature1") -> row.feature1,
+ customFeatureKeyEncoder("feature2") -> row.feature2,
+ "_ts:test-fs" -> row.eventTimestamp
+ )
+ )
+
+ }
}
diff --git a/tests/e2e/requirements.txt b/tests/e2e/requirements.txt
index 68595ee1b5..d8b1950926 100644
--- a/tests/e2e/requirements.txt
+++ b/tests/e2e/requirements.txt
@@ -8,7 +8,9 @@ pytest-mock==1.10.4
pytest-timeout==1.3.3
pytest-ordering==0.6.*
pytest-xdist==2.1.0
-tensorflow-data-validation==0.21.2
+# tensorflow-data-validation==0.21.2
deepdiff==4.3.2
-tensorflow==2.1.0
-tfx-bsl==0.21.* # lock to 0.21
+# tensorflow==2.1.0
+# tfx-bsl==0.21.* # lock to 0.21
+confluent_kafka
+avro==1.10.0
\ No newline at end of file
diff --git a/tests/e2e/test_online_features.py b/tests/e2e/test_online_features.py
index 552ad346cd..fc255c172b 100644
--- a/tests/e2e/test_online_features.py
+++ b/tests/e2e/test_online_features.py
@@ -1,16 +1,30 @@
+import io
+import json
import os
import time
import uuid
from datetime import datetime, timedelta
from pathlib import Path
+import avro.schema
import numpy as np
import pandas as pd
import pyspark
import pytest
-
-from feast import Client, Entity, Feature, FeatureTable, FileSource, ValueType
-from feast.data_format import ParquetFormat
+import pytz
+from avro.io import BinaryEncoder, DatumWriter
+from confluent_kafka import Producer
+
+from feast import (
+ Client,
+ Entity,
+ Feature,
+ FeatureTable,
+ FileSource,
+ KafkaSource,
+ ValueType,
+)
+from feast.data_format import AvroFormat, ParquetFormat
from feast.pyspark.abc import SparkJobStatus
from feast.wait import wait_retry_backoff
@@ -110,11 +124,7 @@ def test_offline_ingestion(feast_client: Client, staging_path: str):
feature_table, datetime.today(), datetime.today() + timedelta(days=1)
)
- status = wait_retry_backoff(
- lambda: (job.get_status(), job.get_status() != SparkJobStatus.IN_PROGRESS), 300
- )
-
- assert status == SparkJobStatus.COMPLETED
+ wait_retry_backoff(lambda: (None, job.get_status() == SparkJobStatus.COMPLETED), 60)
features = feast_client.get_online_features(
["drivers:unique_drivers"],
@@ -128,3 +138,113 @@ def test_offline_ingestion(feast_client: Client, staging_path: str):
columns={"unique_drivers": "drivers:unique_drivers"}
),
)
+
+
+def test_streaming_ingestion(feast_client: Client, staging_path: str, pytestconfig):
+ entity = Entity(name="s2id", description="S2id", value_type=ValueType.INT64,)
+
+ feature_table = FeatureTable(
+ name="drivers_stream",
+ entities=["s2id"],
+ features=[Feature("unique_drivers", ValueType.INT64)],
+ batch_source=FileSource(
+ "event_timestamp",
+ "event_timestamp",
+ ParquetFormat(),
+ os.path.join(staging_path, "batch-storage"),
+ ),
+ stream_source=KafkaSource(
+ "event_timestamp",
+ "event_timestamp",
+ pytestconfig.getoption("kafka_brokers"),
+ AvroFormat(avro_schema()),
+ topic="avro",
+ ),
+ )
+
+ feast_client.apply_entity(entity)
+ feast_client.apply_feature_table(feature_table)
+
+ job = feast_client.start_stream_to_online_ingestion(feature_table)
+
+ wait_retry_backoff(
+ lambda: (None, job.get_status() == SparkJobStatus.IN_PROGRESS), 60
+ )
+
+ try:
+ original = generate_data()[["s2id", "unique_drivers", "event_timestamp"]]
+ for record in original.to_dict("records"):
+ record["event_timestamp"] = (
+ record["event_timestamp"].to_pydatetime().replace(tzinfo=pytz.utc)
+ )
+
+ send_avro_record_to_kafka(
+ "avro",
+ record,
+ bootstrap_servers=pytestconfig.getoption("kafka_brokers"),
+ avro_schema_json=avro_schema(),
+ )
+
+ def get_online_features():
+ features = feast_client.get_online_features(
+ ["drivers_stream:unique_drivers"],
+ entity_rows=[{"s2id": s2_id} for s2_id in original["s2id"].tolist()],
+ ).to_dict()
+ df = pd.DataFrame.from_dict(features)
+ return df, not df["drivers_stream:unique_drivers"].isna().any()
+
+ ingested = wait_retry_backoff(get_online_features, 60)
+ finally:
+ job.cancel()
+
+ pd.testing.assert_frame_equal(
+ ingested[["s2id", "drivers_stream:unique_drivers"]],
+ original[["s2id", "unique_drivers"]].rename(
+ columns={"unique_drivers": "drivers_stream:unique_drivers"}
+ ),
+ )
+
+
+def avro_schema():
+ return json.dumps(
+ {
+ "type": "record",
+ "name": "TestMessage",
+ "fields": [
+ {"name": "s2id", "type": "long"},
+ {"name": "unique_drivers", "type": "long"},
+ {
+ "name": "event_timestamp",
+ "type": {"type": "long", "logicalType": "timestamp-micros"},
+ },
+ ],
+ }
+ )
+
+
+def send_avro_record_to_kafka(topic, value, bootstrap_servers, avro_schema_json):
+ value_schema = avro.schema.parse(avro_schema_json)
+
+ producer_config = {
+ "bootstrap.servers": bootstrap_servers,
+ "request.timeout.ms": "1000",
+ }
+
+ producer = Producer(producer_config)
+
+ writer = DatumWriter(value_schema)
+ bytes_writer = io.BytesIO()
+ encoder = BinaryEncoder(bytes_writer)
+
+ writer.write(value, encoder)
+
+ try:
+ producer.produce(topic=topic, value=bytes_writer.getvalue())
+ except Exception as e:
+ print(
+ f"Exception while producing record value - {value} to topic - {topic}: {e}"
+ )
+ else:
+        print(f"Successfully produced record value - {value} to topic - {topic}")
+
+ producer.flush()