diff --git a/storage/api/src/main/java/feast/storage/common/retry/BackOffExecutor.java b/storage/api/src/main/java/feast/storage/common/retry/BackOffExecutor.java
new file mode 100644
index 0000000000..296582f8b3
--- /dev/null
+++ b/storage/api/src/main/java/feast/storage/common/retry/BackOffExecutor.java
@@ -0,0 +1,58 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright 2018-2020 The Feast Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package feast.storage.common.retry;
+
+import java.io.Serializable;
+import org.apache.beam.sdk.util.BackOff;
+import org.apache.beam.sdk.util.BackOffUtils;
+import org.apache.beam.sdk.util.FluentBackoff;
+import org.apache.beam.sdk.util.Sleeper;
+import org.joda.time.Duration;
+
+public class BackOffExecutor implements Serializable {
+
+ private final Integer maxRetries;
+ private final Duration initialBackOff;
+
+ public BackOffExecutor(Integer maxRetries, Duration initialBackOff) {
+ this.maxRetries = maxRetries;
+ this.initialBackOff = initialBackOff;
+ }
+
+ public void execute(Retriable retriable) throws Exception {
+ FluentBackoff backoff =
+ FluentBackoff.DEFAULT.withMaxRetries(maxRetries).withInitialBackoff(initialBackOff);
+ execute(retriable, backoff);
+ }
+
+ private void execute(Retriable retriable, FluentBackoff backoff) throws Exception {
+ Sleeper sleeper = Sleeper.DEFAULT;
+ BackOff backOff = backoff.backoff();
+ while (true) {
+ try {
+ retriable.execute();
+ break;
+ } catch (Exception e) {
+ if (retriable.isExceptionRetriable(e) && BackOffUtils.next(sleeper, backOff)) {
+ retriable.cleanUpAfterFailure();
+ } else {
+ throw e;
+ }
+ }
+ }
+ }
+}
diff --git a/storage/api/src/main/java/feast/storage/common/retry/Retriable.java b/storage/api/src/main/java/feast/storage/common/retry/Retriable.java
new file mode 100644
index 0000000000..2c92c85175
--- /dev/null
+++ b/storage/api/src/main/java/feast/storage/common/retry/Retriable.java
@@ -0,0 +1,25 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright 2018-2020 The Feast Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package feast.storage.common.retry;
+
+public interface Retriable {
+ void execute() throws Exception;
+
+ Boolean isExceptionRetriable(Exception e);
+
+ void cleanUpAfterFailure();
+}
diff --git a/storage/connectors/redis/pom.xml b/storage/connectors/redis/pom.xml
index 892838efc9..6c50895bd2 100644
--- a/storage/connectors/redis/pom.xml
+++ b/storage/connectors/redis/pom.xml
@@ -13,6 +13,64 @@
Feast Storage Connector for Redis
+
+ io.lettuce
+ lettuce-core
+
+
+
+ org.apache.commons
+ commons-lang3
+ 3.9
+
+
+
+ com.google.auto.value
+ auto-value-annotations
+ 1.6.6
+
+
+
+ com.google.auto.value
+ auto-value
+ 1.6.6
+ provided
+
+
+
+ org.mockito
+ mockito-core
+ 2.23.0
+ test
+
+
+
+
+ com.github.kstyrc
+ embedded-redis
+ test
+
+
+
+ org.apache.beam
+ beam-runners-direct-java
+ ${org.apache.beam.version}
+ test
+
+
+
+ org.hamcrest
+ hamcrest-core
+ test
+
+
+
+ org.hamcrest
+ hamcrest-library
+ test
+
+
+
junit
junit
diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retrieval/FeatureRowDecoder.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retrieval/FeatureRowDecoder.java
new file mode 100644
index 0000000000..7afea1b972
--- /dev/null
+++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retrieval/FeatureRowDecoder.java
@@ -0,0 +1,95 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright 2018-2020 The Feast Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package feast.storage.connectors.redis.retrieval;
+
+import feast.core.FeatureSetProto.FeatureSetSpec;
+import feast.core.FeatureSetProto.FeatureSpec;
+import feast.types.FeatureRowProto.FeatureRow;
+import feast.types.FieldProto.Field;
+import java.util.Comparator;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+public class FeatureRowDecoder {
+
+ private final String featureSetRef;
+ private final FeatureSetSpec spec;
+
+ public FeatureRowDecoder(String featureSetRef, FeatureSetSpec spec) {
+ this.featureSetRef = featureSetRef;
+ this.spec = spec;
+ }
+
+ /**
+ * A feature row is considered encoded if the feature set and field names are not set. This method
+ * is required for backward compatibility purposes, to allow Feast serving to continue serving non
+ * encoded Feature Row ingested by an older version of Feast.
+ *
+ * @param featureRow Feature row
+ * @return boolean
+ */
+ public Boolean isEncoded(FeatureRow featureRow) {
+ return featureRow.getFeatureSet().isEmpty()
+ && featureRow.getFieldsList().stream().allMatch(field -> field.getName().isEmpty());
+ }
+
+ /**
+ * Validates if an encoded feature row can be decoded without exception.
+ *
+ * @param featureRow Feature row
+ * @return boolean
+ */
+ public Boolean isEncodingValid(FeatureRow featureRow) {
+ return featureRow.getFieldsList().size() == spec.getFeaturesList().size();
+ }
+
+ /**
+ * Decoding feature row by repopulating the field names based on the corresponding feature set
+ * spec.
+ *
+ * @param encodedFeatureRow Feature row
+ * @return boolean
+ */
+ public FeatureRow decode(FeatureRow encodedFeatureRow) {
+ final List fieldsWithoutName = encodedFeatureRow.getFieldsList();
+
+ List featureNames =
+ spec.getFeaturesList().stream()
+ .sorted(Comparator.comparing(FeatureSpec::getName))
+ .map(FeatureSpec::getName)
+ .collect(Collectors.toList());
+ List fields =
+ IntStream.range(0, featureNames.size())
+ .mapToObj(
+ featureNameIndex -> {
+ String featureName = featureNames.get(featureNameIndex);
+ return fieldsWithoutName
+ .get(featureNameIndex)
+ .toBuilder()
+ .setName(featureName)
+ .build();
+ })
+ .collect(Collectors.toList());
+ return encodedFeatureRow
+ .toBuilder()
+ .clearFields()
+ .setFeatureSet(featureSetRef)
+ .addAllFields(fields)
+ .build();
+ }
+}
diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retrieval/RedisOnlineRetriever.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retrieval/RedisOnlineRetriever.java
new file mode 100644
index 0000000000..7ff925a722
--- /dev/null
+++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retrieval/RedisOnlineRetriever.java
@@ -0,0 +1,198 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright 2018-2020 The Feast Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package feast.storage.connectors.redis.retrieval;
+
+import com.google.protobuf.AbstractMessageLite;
+import com.google.protobuf.InvalidProtocolBufferException;
+import feast.core.FeatureSetProto.EntitySpec;
+import feast.core.FeatureSetProto.FeatureSetSpec;
+import feast.serving.ServingAPIProto.FeatureReference;
+import feast.serving.ServingAPIProto.GetOnlineFeaturesRequest.EntityRow;
+import feast.storage.RedisProto.RedisKey;
+import feast.storage.api.retrieval.FeatureSetRequest;
+import feast.storage.api.retrieval.OnlineRetriever;
+import feast.types.FeatureRowProto.FeatureRow;
+import feast.types.FieldProto.Field;
+import feast.types.ValueProto.Value;
+import io.grpc.Status;
+import io.lettuce.core.api.StatefulRedisConnection;
+import io.lettuce.core.api.sync.RedisCommands;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ExecutionException;
+import java.util.stream.Collectors;
+
+public class RedisOnlineRetriever implements OnlineRetriever {
+
+ private final RedisCommands syncCommands;
+
+ public RedisOnlineRetriever(StatefulRedisConnection connection) {
+ this.syncCommands = connection.sync();
+ }
+
+ /**
+ * Gets online features from redis. This method returns a list of {@link FeatureRow}s
+ * corresponding to each feature set spec. Each feature row in the list then corresponds to an
+ * {@link EntityRow} provided by the user.
+ *
+ * @param entityRows list of entity rows in the feature request
+ * @param featureSetRequests Map of {@link feast.core.FeatureSetProto.FeatureSetSpec} to feature
+ * references in the request tied to that feature set.
+ * @return List of List of {@link FeatureRow}
+ */
+ @Override
+ public List> getOnlineFeatures(
+ List entityRows, List featureSetRequests) {
+
+ List> featureRows = new ArrayList<>();
+ for (FeatureSetRequest featureSetRequest : featureSetRequests) {
+ List redisKeys = buildRedisKeys(entityRows, featureSetRequest.getSpec());
+ try {
+ List featureRowsForFeatureSet =
+ sendAndProcessMultiGet(
+ redisKeys,
+ featureSetRequest.getSpec(),
+ featureSetRequest.getFeatureReferences().asList());
+ featureRows.add(featureRowsForFeatureSet);
+ } catch (InvalidProtocolBufferException | ExecutionException e) {
+ throw Status.INTERNAL
+ .withDescription("Unable to parse protobuf while retrieving feature")
+ .withCause(e)
+ .asRuntimeException();
+ }
+ }
+ return featureRows;
+ }
+
+ private List buildRedisKeys(List entityRows, FeatureSetSpec featureSetSpec) {
+ String featureSetRef = generateFeatureSetStringRef(featureSetSpec);
+ List featureSetEntityNames =
+ featureSetSpec.getEntitiesList().stream()
+ .map(EntitySpec::getName)
+ .collect(Collectors.toList());
+ List redisKeys =
+ entityRows.stream()
+ .map(row -> makeRedisKey(featureSetRef, featureSetEntityNames, row))
+ .collect(Collectors.toList());
+ return redisKeys;
+ }
+
+ /**
+ * Create {@link RedisKey}
+ *
+ * @param featureSet featureSet reference of the feature. E.g. feature_set_1:1
+ * @param featureSetEntityNames entity names that belong to the featureSet
+ * @param entityRow entityRow to build the key from
+ * @return {@link RedisKey}
+ */
+ private RedisKey makeRedisKey(
+ String featureSet, List featureSetEntityNames, EntityRow entityRow) {
+ RedisKey.Builder builder = RedisKey.newBuilder().setFeatureSet(featureSet);
+ Map fieldsMap = entityRow.getFieldsMap();
+ featureSetEntityNames.sort(String::compareTo);
+ for (int i = 0; i < featureSetEntityNames.size(); i++) {
+ String entityName = featureSetEntityNames.get(i);
+
+ if (!fieldsMap.containsKey(entityName)) {
+ throw Status.INVALID_ARGUMENT
+ .withDescription(
+ String.format(
+ "Entity row fields \"%s\" does not contain required entity field \"%s\"",
+ fieldsMap.keySet().toString(), entityName))
+ .asRuntimeException();
+ }
+
+ builder.addEntities(
+ Field.newBuilder().setName(entityName).setValue(fieldsMap.get(entityName)));
+ }
+ return builder.build();
+ }
+
+ private List sendAndProcessMultiGet(
+ List redisKeys,
+ FeatureSetSpec featureSetSpec,
+ List featureReferences)
+ throws InvalidProtocolBufferException, ExecutionException {
+
+ List values = sendMultiGet(redisKeys);
+ List featureRows = new ArrayList<>();
+
+ FeatureRow.Builder nullFeatureRowBuilder =
+ FeatureRow.newBuilder().setFeatureSet(generateFeatureSetStringRef(featureSetSpec));
+ for (FeatureReference featureReference : featureReferences) {
+ nullFeatureRowBuilder.addFields(Field.newBuilder().setName(featureReference.getName()));
+ }
+
+ for (int i = 0; i < values.size(); i++) {
+
+ byte[] value = values.get(i);
+ if (value == null) {
+ featureRows.add(nullFeatureRowBuilder.build());
+ continue;
+ }
+
+ FeatureRow featureRow = FeatureRow.parseFrom(value);
+ String featureSetRef = redisKeys.get(i).getFeatureSet();
+ FeatureRowDecoder decoder = new FeatureRowDecoder(featureSetRef, featureSetSpec);
+ if (decoder.isEncoded(featureRow)) {
+ if (decoder.isEncodingValid(featureRow)) {
+ featureRow = decoder.decode(featureRow);
+ } else {
+ featureRows.add(nullFeatureRowBuilder.build());
+ continue;
+ }
+ }
+
+ featureRows.add(featureRow);
+ }
+ return featureRows;
+ }
+
+ /**
+ * Send a list of get request as an mget
+ *
+ * @param keys list of {@link RedisKey}
+ * @return list of {@link FeatureRow} in primitive byte representation for each {@link RedisKey}
+ */
+ private List sendMultiGet(List keys) {
+ try {
+ byte[][] binaryKeys =
+ keys.stream()
+ .map(AbstractMessageLite::toByteArray)
+ .collect(Collectors.toList())
+ .toArray(new byte[0][0]);
+ return syncCommands.mget(binaryKeys).stream()
+ .map(keyValue -> keyValue.getValueOrElse(null))
+ .collect(Collectors.toList());
+ } catch (Exception e) {
+ throw Status.NOT_FOUND
+ .withDescription("Unable to retrieve feature from Redis")
+ .withCause(e)
+ .asRuntimeException();
+ }
+ }
+
+ // TODO: Refactor this out to common package?
+ private static String generateFeatureSetStringRef(FeatureSetSpec featureSetSpec) {
+ String ref = String.format("%s/%s", featureSetSpec.getProject(), featureSetSpec.getName());
+ if (featureSetSpec.getVersion() > 0) {
+ return ref + String.format(":%d", featureSetSpec.getVersion());
+ }
+ return ref;
+ }
+}
diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/write/RedisCustomIO.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/write/RedisCustomIO.java
new file mode 100644
index 0000000000..e53861b99d
--- /dev/null
+++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/write/RedisCustomIO.java
@@ -0,0 +1,292 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright 2018-2019 The Feast Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package feast.storage.connectors.redis.write;
+
+import feast.core.FeatureSetProto.EntitySpec;
+import feast.core.FeatureSetProto.FeatureSetSpec;
+import feast.core.FeatureSetProto.FeatureSpec;
+import feast.core.StoreProto.Store.RedisConfig;
+import feast.storage.RedisProto.RedisKey;
+import feast.storage.RedisProto.RedisKey.Builder;
+import feast.storage.api.write.FailedElement;
+import feast.storage.api.write.WriteResult;
+import feast.storage.common.retry.Retriable;
+import feast.types.FeatureRowProto.FeatureRow;
+import feast.types.FieldProto.Field;
+import feast.types.ValueProto;
+import io.lettuce.core.RedisConnectionException;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ExecutionException;
+import java.util.stream.Collectors;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.TupleTagList;
+import org.apache.commons.lang3.exception.ExceptionUtils;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class RedisCustomIO {
+
+ private static final int DEFAULT_BATCH_SIZE = 1000;
+ private static final int DEFAULT_TIMEOUT = 2000;
+
+ private static TupleTag successfulInsertsTag = new TupleTag<>("successfulInserts") {};
+ private static TupleTag failedInsertsTupleTag = new TupleTag<>("failedInserts") {};
+
+ private static final Logger log = LoggerFactory.getLogger(RedisCustomIO.class);
+
+ private RedisCustomIO() {}
+
+ public static Write write(RedisConfig redisConfig, Map featureSetSpecs) {
+ return new Write(redisConfig, featureSetSpecs);
+ }
+
+ /** ServingStoreWrite data to a Redis server. */
+ public static class Write extends PTransform, WriteResult> {
+
+ private Map featureSetSpecs;
+ private RedisConfig redisConfig;
+ private int batchSize;
+ private int timeout;
+
+ public Write(RedisConfig redisConfig, Map featureSetSpecs) {
+
+ this.redisConfig = redisConfig;
+ this.featureSetSpecs = featureSetSpecs;
+ }
+
+ public Write withBatchSize(int batchSize) {
+ this.batchSize = batchSize;
+ return this;
+ }
+
+ public Write withTimeout(int timeout) {
+ this.timeout = timeout;
+ return this;
+ }
+
+ @Override
+ public WriteResult expand(PCollection input) {
+ PCollectionTuple redisWrite =
+ input.apply(
+ ParDo.of(new WriteDoFn(redisConfig, featureSetSpecs))
+ .withOutputTags(successfulInsertsTag, TupleTagList.of(failedInsertsTupleTag)));
+ return WriteResult.in(
+ input.getPipeline(),
+ redisWrite.get(successfulInsertsTag),
+ redisWrite.get(failedInsertsTupleTag));
+ }
+
+ public static class WriteDoFn extends DoFn {
+
+ private final List featureRows = new ArrayList<>();
+ private Map featureSetSpecs;
+ private int batchSize = DEFAULT_BATCH_SIZE;
+ private int timeout = DEFAULT_TIMEOUT;
+ private RedisIngestionClient redisIngestionClient;
+
+ WriteDoFn(RedisConfig config, Map featureSetSpecs) {
+
+ this.redisIngestionClient = new RedisStandaloneIngestionClient(config);
+ this.featureSetSpecs = featureSetSpecs;
+ }
+
+ public WriteDoFn withBatchSize(int batchSize) {
+ if (batchSize > 0) {
+ this.batchSize = batchSize;
+ }
+ return this;
+ }
+
+ public WriteDoFn withTimeout(int timeout) {
+ if (timeout > 0) {
+ this.timeout = timeout;
+ }
+ return this;
+ }
+
+ @Setup
+ public void setup() {
+ this.redisIngestionClient.setup();
+ }
+
+ @StartBundle
+ public void startBundle() {
+ try {
+ redisIngestionClient.connect();
+ } catch (RedisConnectionException e) {
+ log.error("Connection to redis cannot be established ", e);
+ }
+ featureRows.clear();
+ }
+
+ private void executeBatch() throws Exception {
+ this.redisIngestionClient
+ .getBackOffExecutor()
+ .execute(
+ new Retriable() {
+ @Override
+ public void execute() throws ExecutionException, InterruptedException {
+ if (!redisIngestionClient.isConnected()) {
+ redisIngestionClient.connect();
+ }
+ featureRows.forEach(
+ row -> {
+ redisIngestionClient.set(getKey(row), getValue(row));
+ });
+ redisIngestionClient.sync();
+ }
+
+ @Override
+ public Boolean isExceptionRetriable(Exception e) {
+ return e instanceof RedisConnectionException;
+ }
+
+ @Override
+ public void cleanUpAfterFailure() {}
+ });
+ }
+
+ private FailedElement toFailedElement(
+ FeatureRow featureRow, Exception exception, String jobName) {
+ return FailedElement.newBuilder()
+ .setJobName(jobName)
+ .setTransformName("RedisCustomIO")
+ .setPayload(featureRow.toString())
+ .setErrorMessage(exception.getMessage())
+ .setStackTrace(ExceptionUtils.getStackTrace(exception))
+ .build();
+ }
+
+ private byte[] getKey(FeatureRow featureRow) {
+ FeatureSetSpec featureSetSpec = featureSetSpecs.get(featureRow.getFeatureSet());
+ List entityNames =
+ featureSetSpec.getEntitiesList().stream()
+ .map(EntitySpec::getName)
+ .sorted()
+ .collect(Collectors.toList());
+
+ Map entityFields = new HashMap<>();
+ Builder redisKeyBuilder = RedisKey.newBuilder().setFeatureSet(featureRow.getFeatureSet());
+ for (Field field : featureRow.getFieldsList()) {
+ if (entityNames.contains(field.getName())) {
+ entityFields.putIfAbsent(
+ field.getName(),
+ Field.newBuilder().setName(field.getName()).setValue(field.getValue()).build());
+ }
+ }
+ for (String entityName : entityNames) {
+ redisKeyBuilder.addEntities(entityFields.get(entityName));
+ }
+ return redisKeyBuilder.build().toByteArray();
+ }
+
+ private byte[] getValue(FeatureRow featureRow) {
+ FeatureSetSpec spec = featureSetSpecs.get(featureRow.getFeatureSet());
+
+ List featureNames =
+ spec.getFeaturesList().stream().map(FeatureSpec::getName).collect(Collectors.toList());
+ Map fieldValueOnlyMap =
+ featureRow.getFieldsList().stream()
+ .filter(field -> featureNames.contains(field.getName()))
+ .distinct()
+ .collect(
+ Collectors.toMap(
+ Field::getName,
+ field -> Field.newBuilder().setValue(field.getValue()).build()));
+
+ List values =
+ featureNames.stream()
+ .sorted()
+ .map(
+ featureName ->
+ fieldValueOnlyMap.getOrDefault(
+ featureName,
+ Field.newBuilder()
+ .setValue(ValueProto.Value.getDefaultInstance())
+ .build()))
+ .collect(Collectors.toList());
+
+ return FeatureRow.newBuilder()
+ .setEventTimestamp(featureRow.getEventTimestamp())
+ .addAllFields(values)
+ .build()
+ .toByteArray();
+ }
+
+ @ProcessElement
+ public void processElement(ProcessContext context) {
+ FeatureRow featureRow = context.element();
+ featureRows.add(featureRow);
+ if (featureRows.size() >= batchSize) {
+ try {
+ executeBatch();
+ featureRows.forEach(row -> context.output(successfulInsertsTag, row));
+ featureRows.clear();
+ } catch (Exception e) {
+ featureRows.forEach(
+ failedMutation -> {
+ FailedElement failedElement =
+ toFailedElement(failedMutation, e, context.getPipelineOptions().getJobName());
+ context.output(failedInsertsTupleTag, failedElement);
+ });
+ featureRows.clear();
+ }
+ }
+ }
+
+ @FinishBundle
+ public void finishBundle(FinishBundleContext context)
+ throws IOException, InterruptedException {
+ if (featureRows.size() > 0) {
+ try {
+ executeBatch();
+ featureRows.forEach(
+ row ->
+ context.output(
+ successfulInsertsTag, row, Instant.now(), GlobalWindow.INSTANCE));
+ featureRows.clear();
+ } catch (Exception e) {
+ featureRows.forEach(
+ failedMutation -> {
+ FailedElement failedElement =
+ toFailedElement(failedMutation, e, context.getPipelineOptions().getJobName());
+ context.output(
+ failedInsertsTupleTag, failedElement, Instant.now(), GlobalWindow.INSTANCE);
+ });
+ featureRows.clear();
+ }
+ }
+ }
+
+ @Teardown
+ public void teardown() {
+ redisIngestionClient.shutdown();
+ }
+ }
+ }
+}
diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/write/RedisFeatureSink.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/write/RedisFeatureSink.java
new file mode 100644
index 0000000000..2f566bb78c
--- /dev/null
+++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/write/RedisFeatureSink.java
@@ -0,0 +1,74 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright 2018-2020 The Feast Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package feast.storage.connectors.redis.write;
+
+import com.google.auto.value.AutoValue;
+import feast.core.FeatureSetProto.FeatureSet;
+import feast.core.FeatureSetProto.FeatureSetSpec;
+import feast.core.StoreProto.Store.RedisConfig;
+import feast.storage.api.write.FeatureSink;
+import feast.storage.api.write.WriteResult;
+import feast.types.FeatureRowProto.FeatureRow;
+import io.lettuce.core.RedisClient;
+import io.lettuce.core.RedisConnectionException;
+import io.lettuce.core.RedisURI;
+import java.util.Map;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.values.PCollection;
+
+@AutoValue
+public abstract class RedisFeatureSink implements FeatureSink {
+
+ public abstract RedisConfig getRedisConfig();
+
+ public abstract Map getFeatureSetSpecs();
+
+ public abstract Builder toBuilder();
+
+ public static Builder builder() {
+ return new AutoValue_RedisFeatureSink.Builder();
+ }
+
+ @AutoValue.Builder
+ public abstract static class Builder {
+ public abstract Builder setRedisConfig(RedisConfig redisConfig);
+
+ public abstract Builder setFeatureSetSpecs(Map featureSetSpecs);
+
+ public abstract RedisFeatureSink build();
+ }
+
+ @Override
+ public void prepareWrite(FeatureSet featureSet) {
+ RedisClient redisClient =
+ RedisClient.create(RedisURI.create(getRedisConfig().getHost(), getRedisConfig().getPort()));
+ try {
+ redisClient.connect();
+ } catch (RedisConnectionException e) {
+ throw new RuntimeException(
+ String.format(
+ "Failed to connect to Redis at host: '%s' port: '%d'. Please check that your Redis is running and accessible from Feast.",
+ getRedisConfig().getHost(), getRedisConfig().getPort()));
+ }
+ redisClient.shutdown();
+ }
+
+ @Override
+ public PTransform, WriteResult> write() {
+ return new RedisCustomIO.Write(getRedisConfig(), getFeatureSetSpecs());
+ }
+}
diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/write/RedisIngestionClient.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/write/RedisIngestionClient.java
new file mode 100644
index 0000000000..7004b94282
--- /dev/null
+++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/write/RedisIngestionClient.java
@@ -0,0 +1,49 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright 2018-2020 The Feast Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package feast.storage.connectors.redis.write;
+
+import feast.storage.common.retry.BackOffExecutor;
+import java.io.Serializable;
+
+public interface RedisIngestionClient extends Serializable {
+
+ void setup();
+
+ BackOffExecutor getBackOffExecutor();
+
+ void shutdown();
+
+ void connect();
+
+ boolean isConnected();
+
+ void sync();
+
+ void pexpire(byte[] key, Long expiryMillis);
+
+ void append(byte[] key, byte[] value);
+
+ void set(byte[] key, byte[] value);
+
+ void lpush(byte[] key, byte[] value);
+
+ void rpush(byte[] key, byte[] value);
+
+ void sadd(byte[] key, byte[] value);
+
+ void zadd(byte[] key, Long score, byte[] value);
+}
diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/write/RedisStandaloneIngestionClient.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/write/RedisStandaloneIngestionClient.java
new file mode 100644
index 0000000000..583eade2aa
--- /dev/null
+++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/write/RedisStandaloneIngestionClient.java
@@ -0,0 +1,125 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright 2018-2020 The Feast Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package feast.storage.connectors.redis.write;
+
+import com.google.common.collect.Lists;
+import feast.core.StoreProto;
+import feast.storage.common.retry.BackOffExecutor;
+import io.lettuce.core.LettuceFutures;
+import io.lettuce.core.RedisClient;
+import io.lettuce.core.RedisFuture;
+import io.lettuce.core.RedisURI;
+import io.lettuce.core.api.StatefulRedisConnection;
+import io.lettuce.core.api.async.RedisAsyncCommands;
+import io.lettuce.core.codec.ByteArrayCodec;
+import java.util.List;
+import java.util.concurrent.TimeUnit;
+import org.joda.time.Duration;
+
+public class RedisStandaloneIngestionClient implements RedisIngestionClient {
+ private final String host;
+ private final int port;
+ private final BackOffExecutor backOffExecutor;
+ private RedisClient redisclient;
+ private static final int DEFAULT_TIMEOUT = 2000;
+ private StatefulRedisConnection connection;
+ private RedisAsyncCommands commands;
+ private List futures = Lists.newArrayList();
+
+ public RedisStandaloneIngestionClient(StoreProto.Store.RedisConfig redisConfig) {
+ this.host = redisConfig.getHost();
+ this.port = redisConfig.getPort();
+ long backoffMs = redisConfig.getInitialBackoffMs() > 0 ? redisConfig.getInitialBackoffMs() : 1;
+ this.backOffExecutor =
+ new BackOffExecutor(redisConfig.getMaxRetries(), Duration.millis(backoffMs));
+ }
+
+ @Override
+ public void setup() {
+ this.redisclient =
+ RedisClient.create(new RedisURI(host, port, java.time.Duration.ofMillis(DEFAULT_TIMEOUT)));
+ }
+
+ @Override
+ public BackOffExecutor getBackOffExecutor() {
+ return this.backOffExecutor;
+ }
+
+ @Override
+ public void shutdown() {
+ this.redisclient.shutdown();
+ }
+
+ @Override
+ public void connect() {
+ if (!isConnected()) {
+ this.connection = this.redisclient.connect(new ByteArrayCodec());
+ this.commands = connection.async();
+ }
+ }
+
+ @Override
+ public boolean isConnected() {
+ return connection != null;
+ }
+
+ @Override
+ public void sync() {
+ // Wait for some time for futures to complete
+ // TODO: should this be configurable?
+ try {
+ LettuceFutures.awaitAll(60, TimeUnit.SECONDS, futures.toArray(new RedisFuture[0]));
+ } finally {
+ futures.clear();
+ }
+ }
+
+ @Override
+ public void pexpire(byte[] key, Long expiryMillis) {
+ commands.pexpire(key, expiryMillis);
+ }
+
+ @Override
+ public void append(byte[] key, byte[] value) {
+ futures.add(commands.append(key, value));
+ }
+
+ @Override
+ public void set(byte[] key, byte[] value) {
+ futures.add(commands.set(key, value));
+ }
+
+ @Override
+ public void lpush(byte[] key, byte[] value) {
+ futures.add(commands.lpush(key, value));
+ }
+
+ @Override
+ public void rpush(byte[] key, byte[] value) {
+ futures.add(commands.rpush(key, value));
+ }
+
+ @Override
+ public void sadd(byte[] key, byte[] value) {
+ futures.add(commands.sadd(key, value));
+ }
+
+ @Override
+ public void zadd(byte[] key, Long score, byte[] value) {
+ futures.add(commands.zadd(key, score, value));
+ }
+}
diff --git a/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/retrieval/RedisOnlineRetrieverTest.java b/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/retrieval/RedisOnlineRetrieverTest.java
new file mode 100644
index 0000000000..9fabc5fae7
--- /dev/null
+++ b/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/retrieval/RedisOnlineRetrieverTest.java
@@ -0,0 +1,275 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright 2018-2020 The Feast Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package feast.storage.connectors.redis.retrieval;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.equalTo;
+import static org.mockito.Mockito.when;
+import static org.mockito.MockitoAnnotations.initMocks;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
+import com.google.protobuf.AbstractMessageLite;
+import com.google.protobuf.Duration;
+import com.google.protobuf.Timestamp;
+import feast.core.FeatureSetProto.EntitySpec;
+import feast.core.FeatureSetProto.FeatureSetSpec;
+import feast.core.FeatureSetProto.FeatureSpec;
+import feast.serving.ServingAPIProto.FeatureReference;
+import feast.serving.ServingAPIProto.GetOnlineFeaturesRequest.EntityRow;
+import feast.storage.RedisProto.RedisKey;
+import feast.storage.api.retrieval.FeatureSetRequest;
+import feast.types.FeatureRowProto.FeatureRow;
+import feast.types.FieldProto.Field;
+import feast.types.ValueProto.Value;
+import io.lettuce.core.KeyValue;
+import io.lettuce.core.api.StatefulRedisConnection;
+import io.lettuce.core.api.sync.RedisCommands;
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mock;
+
+public class RedisOnlineRetrieverTest {
+
+ @Mock StatefulRedisConnection connection;
+
+ @Mock RedisCommands syncCommands;
+
+ private RedisOnlineRetriever redisOnlineRetriever;
+ private byte[][] redisKeyList;
+
+ @Before
+ public void setUp() {
+ initMocks(this);
+ when(connection.sync()).thenReturn(syncCommands);
+ redisOnlineRetriever = new RedisOnlineRetriever(connection);
+ redisKeyList =
+ Lists.newArrayList(
+ RedisKey.newBuilder()
+ .setFeatureSet("project/featureSet:1")
+ .addAllEntities(
+ Lists.newArrayList(
+ Field.newBuilder().setName("entity1").setValue(intValue(1)).build(),
+ Field.newBuilder().setName("entity2").setValue(strValue("a")).build()))
+ .build(),
+ RedisKey.newBuilder()
+ .setFeatureSet("project/featureSet:1")
+ .addAllEntities(
+ Lists.newArrayList(
+ Field.newBuilder().setName("entity1").setValue(intValue(2)).build(),
+ Field.newBuilder().setName("entity2").setValue(strValue("b")).build()))
+ .build())
+ .stream()
+ .map(AbstractMessageLite::toByteArray)
+ .collect(Collectors.toList())
+ .toArray(new byte[0][0]);
+ }
+
+ @Test
+ public void shouldReturnResponseWithValuesIfKeysPresent() {
+ FeatureSetRequest featureSetRequest =
+ FeatureSetRequest.newBuilder()
+ .setSpec(getFeatureSetSpec())
+ .addFeatureReference(
+ FeatureReference.newBuilder()
+ .setName("feature1")
+ .setVersion(1)
+ .setProject("project")
+ .build())
+ .addFeatureReference(
+ FeatureReference.newBuilder()
+ .setName("feature2")
+ .setVersion(1)
+ .setProject("project")
+ .build())
+ .build();
+ List entityRows =
+ ImmutableList.of(
+ EntityRow.newBuilder()
+ .setEntityTimestamp(Timestamp.newBuilder().setSeconds(100))
+ .putFields("entity1", intValue(1))
+ .putFields("entity2", strValue("a"))
+ .build(),
+ EntityRow.newBuilder()
+ .setEntityTimestamp(Timestamp.newBuilder().setSeconds(100))
+ .putFields("entity1", intValue(2))
+ .putFields("entity2", strValue("b"))
+ .build());
+
+ List featureRows =
+ Lists.newArrayList(
+ FeatureRow.newBuilder()
+ .setEventTimestamp(Timestamp.newBuilder().setSeconds(100))
+ .addAllFields(
+ Lists.newArrayList(
+ Field.newBuilder().setValue(intValue(1)).build(),
+ Field.newBuilder().setValue(intValue(1)).build()))
+ .build(),
+ FeatureRow.newBuilder()
+ .setEventTimestamp(Timestamp.newBuilder().setSeconds(100))
+ .addAllFields(
+ Lists.newArrayList(
+ Field.newBuilder().setValue(intValue(2)).build(),
+ Field.newBuilder().setValue(intValue(2)).build()))
+ .build());
+
+ List> featureRowBytes =
+ featureRows.stream()
+ .map(x -> KeyValue.from(new byte[1], Optional.of(x.toByteArray())))
+ .collect(Collectors.toList());
+
+ redisOnlineRetriever = new RedisOnlineRetriever(connection);
+ when(connection.sync()).thenReturn(syncCommands);
+ when(syncCommands.mget(redisKeyList)).thenReturn(featureRowBytes);
+
+ List> expected =
+ List.of(
+ Lists.newArrayList(
+ FeatureRow.newBuilder()
+ .setEventTimestamp(Timestamp.newBuilder().setSeconds(100))
+ .setFeatureSet("project/featureSet:1")
+ .addAllFields(
+ Lists.newArrayList(
+ Field.newBuilder().setName("feature1").setValue(intValue(1)).build(),
+ Field.newBuilder().setName("feature2").setValue(intValue(1)).build()))
+ .build(),
+ FeatureRow.newBuilder()
+ .setEventTimestamp(Timestamp.newBuilder().setSeconds(100))
+ .setFeatureSet("project/featureSet:1")
+ .addAllFields(
+ Lists.newArrayList(
+ Field.newBuilder().setName("feature1").setValue(intValue(2)).build(),
+ Field.newBuilder().setName("feature2").setValue(intValue(2)).build()))
+ .build()));
+
+ List> actual =
+ redisOnlineRetriever.getOnlineFeatures(entityRows, List.of(featureSetRequest));
+ assertThat(actual, equalTo(expected));
+ }
+
+ @Test
+ public void shouldReturnResponseWithUnsetValuesIfKeysNotPresent() {
+ FeatureSetRequest featureSetRequest =
+ FeatureSetRequest.newBuilder()
+ .setSpec(getFeatureSetSpec())
+ .addFeatureReference(
+ FeatureReference.newBuilder()
+ .setName("feature1")
+ .setVersion(1)
+ .setProject("project")
+ .build())
+ .addFeatureReference(
+ FeatureReference.newBuilder()
+ .setName("feature2")
+ .setVersion(1)
+ .setProject("project")
+ .build())
+ .build();
+ List entityRows =
+ ImmutableList.of(
+ EntityRow.newBuilder()
+ .setEntityTimestamp(Timestamp.newBuilder().setSeconds(100))
+ .putFields("entity1", intValue(1))
+ .putFields("entity2", strValue("a"))
+ .build(),
+ EntityRow.newBuilder()
+ .setEntityTimestamp(Timestamp.newBuilder().setSeconds(100))
+ .putFields("entity1", intValue(2))
+ .putFields("entity2", strValue("b"))
+ .build());
+
+ List featureRows =
+ Lists.newArrayList(
+ FeatureRow.newBuilder()
+ .setEventTimestamp(Timestamp.newBuilder().setSeconds(100))
+ .addAllFields(
+ Lists.newArrayList(
+ Field.newBuilder().setValue(intValue(1)).build(),
+ Field.newBuilder().setValue(intValue(1)).build()))
+ .build());
+
+ List> featureRowBytes =
+ featureRows.stream()
+ .map(x -> KeyValue.from(new byte[1], Optional.of(x.toByteArray())))
+ .collect(Collectors.toList());
+ featureRowBytes.add(null);
+
+ redisOnlineRetriever = new RedisOnlineRetriever(connection);
+ when(connection.sync()).thenReturn(syncCommands);
+ when(syncCommands.mget(redisKeyList)).thenReturn(featureRowBytes);
+
+ List> expected =
+ List.of(
+ Lists.newArrayList(
+ FeatureRow.newBuilder()
+ .setEventTimestamp(Timestamp.newBuilder().setSeconds(100))
+ .setFeatureSet("project/featureSet:1")
+ .addAllFields(
+ Lists.newArrayList(
+ Field.newBuilder().setName("feature1").setValue(intValue(1)).build(),
+ Field.newBuilder().setName("feature2").setValue(intValue(1)).build()))
+ .build(),
+ FeatureRow.newBuilder()
+ .setFeatureSet("project/featureSet:1")
+ .addAllFields(
+ Lists.newArrayList(
+ Field.newBuilder().setName("feature1").build(),
+ Field.newBuilder().setName("feature2").build()))
+ .build()));
+
+ List> actual =
+ redisOnlineRetriever.getOnlineFeatures(entityRows, List.of(featureSetRequest));
+ assertThat(actual, equalTo(expected));
+ }
+
+ private Value intValue(int val) {
+ return Value.newBuilder().setInt64Val(val).build();
+ }
+
+ private Value strValue(String val) {
+ return Value.newBuilder().setStringVal(val).build();
+ }
+
+ private FeatureSetSpec getFeatureSetSpec() {
+ return FeatureSetSpec.newBuilder()
+ .setProject("project")
+ .setName("featureSet")
+ .setVersion(1)
+ .addEntities(EntitySpec.newBuilder().setName("entity1"))
+ .addEntities(EntitySpec.newBuilder().setName("entity2"))
+ .addFeatures(FeatureSpec.newBuilder().setName("feature1"))
+ .addFeatures(FeatureSpec.newBuilder().setName("feature2"))
+ .setMaxAge(Duration.newBuilder().setSeconds(30)) // default
+ .build();
+ }
+
+ private FeatureSetSpec getFeatureSetSpecWithNoMaxAge() {
+ return FeatureSetSpec.newBuilder()
+ .setProject("project")
+ .setName("featureSet")
+ .setVersion(1)
+ .addEntities(EntitySpec.newBuilder().setName("entity1"))
+ .addEntities(EntitySpec.newBuilder().setName("entity2"))
+ .addFeatures(FeatureSpec.newBuilder().setName("feature1"))
+ .addFeatures(FeatureSpec.newBuilder().setName("feature2"))
+ .setMaxAge(Duration.newBuilder().setSeconds(0).setNanos(0).build())
+ .build();
+ }
+}
diff --git a/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/test/TestUtil.java b/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/test/TestUtil.java
new file mode 100644
index 0000000000..66aba44bc2
--- /dev/null
+++ b/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/test/TestUtil.java
@@ -0,0 +1,44 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright 2018-2020 The Feast Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package feast.storage.connectors.redis.test;
+
+import java.io.IOException;
+import redis.embedded.RedisServer;
+
+public class TestUtil {
+ public static class LocalRedis {
+
+ private static RedisServer server;
+
+ /**
+ * Start local Redis for used in testing at "localhost"
+ *
+ * @param port port number
+ * @throws IOException if Redis failed to start
+ */
+ public static void start(int port) throws IOException {
+ server = new RedisServer(port);
+ server.start();
+ }
+
+ public static void stop() {
+ if (server != null) {
+ server.stop();
+ }
+ }
+ }
+}
diff --git a/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/write/RedisFeatureSinkTest.java b/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/write/RedisFeatureSinkTest.java
new file mode 100644
index 0000000000..ddabed8fad
--- /dev/null
+++ b/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/write/RedisFeatureSinkTest.java
@@ -0,0 +1,479 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright 2018-2019 The Feast Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package feast.storage.connectors.redis.write;
+
+import static feast.storage.common.testing.TestUtil.field;
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.protobuf.Timestamp;
+import feast.core.FeatureSetProto.EntitySpec;
+import feast.core.FeatureSetProto.FeatureSetSpec;
+import feast.core.FeatureSetProto.FeatureSpec;
+import feast.core.StoreProto;
+import feast.core.StoreProto.Store.RedisConfig;
+import feast.storage.RedisProto.RedisKey;
+import feast.types.FeatureRowProto.FeatureRow;
+import feast.types.FieldProto.Field;
+import feast.types.ValueProto.Value;
+import feast.types.ValueProto.ValueType.Enum;
+import io.lettuce.core.RedisClient;
+import io.lettuce.core.RedisURI;
+import io.lettuce.core.api.StatefulRedisConnection;
+import io.lettuce.core.api.sync.RedisStringCommands;
+import io.lettuce.core.codec.ByteArrayCodec;
+import java.io.IOException;
+import java.util.*;
+import java.util.concurrent.ScheduledFuture;
+import java.util.concurrent.ScheduledThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.values.PCollection;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import redis.embedded.Redis;
+import redis.embedded.RedisServer;
+
+public class RedisFeatureSinkTest {
+ @Rule public transient TestPipeline p = TestPipeline.create();
+
+ private static String REDIS_HOST = "localhost";
+ private static int REDIS_PORT = 51234;
+ private Redis redis;
+ private RedisClient redisClient;
+ private RedisStringCommands sync;
+
+ private RedisFeatureSink redisFeatureSink;
+
+ @Before
+ public void setUp() throws IOException {
+ redis = new RedisServer(REDIS_PORT);
+ redis.start();
+ redisClient =
+ RedisClient.create(new RedisURI(REDIS_HOST, REDIS_PORT, java.time.Duration.ofMillis(2000)));
+ StatefulRedisConnection connection = redisClient.connect(new ByteArrayCodec());
+ sync = connection.sync();
+
+ FeatureSetSpec spec1 =
+ FeatureSetSpec.newBuilder()
+ .setName("fs")
+ .setVersion(1)
+ .setProject("myproject")
+ .addEntities(EntitySpec.newBuilder().setName("entity").setValueType(Enum.INT64).build())
+ .addFeatures(
+ FeatureSpec.newBuilder().setName("feature").setValueType(Enum.STRING).build())
+ .build();
+
+ FeatureSetSpec spec2 =
+ FeatureSetSpec.newBuilder()
+ .setName("feature_set")
+ .setProject("myproject")
+ .setVersion(1)
+ .addEntities(
+ EntitySpec.newBuilder()
+ .setName("entity_id_primary")
+ .setValueType(Enum.INT32)
+ .build())
+ .addEntities(
+ EntitySpec.newBuilder()
+ .setName("entity_id_secondary")
+ .setValueType(Enum.STRING)
+ .build())
+ .addFeatures(
+ FeatureSpec.newBuilder().setName("feature_1").setValueType(Enum.STRING).build())
+ .addFeatures(
+ FeatureSpec.newBuilder().setName("feature_2").setValueType(Enum.INT64).build())
+ .build();
+
+ Map specMap =
+ ImmutableMap.of("myproject/fs:1", spec1, "myproject/feature_set:1", spec2);
+ StoreProto.Store.RedisConfig redisConfig =
+ StoreProto.Store.RedisConfig.newBuilder().setHost(REDIS_HOST).setPort(REDIS_PORT).build();
+
+ redisFeatureSink =
+ RedisFeatureSink.builder().setFeatureSetSpecs(specMap).setRedisConfig(redisConfig).build();
+ }
+
+ @After
+ public void teardown() {
+ redisClient.shutdown();
+ redis.stop();
+ }
+
+ @Test
+ public void shouldWriteToRedis() {
+
+ HashMap kvs = new LinkedHashMap<>();
+ kvs.put(
+ RedisKey.newBuilder()
+ .setFeatureSet("myproject/fs:1")
+ .addEntities(field("entity", 1, Enum.INT64))
+ .build(),
+ FeatureRow.newBuilder()
+ .setEventTimestamp(Timestamp.getDefaultInstance())
+ .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("one")))
+ .build());
+ kvs.put(
+ RedisKey.newBuilder()
+ .setFeatureSet("myproject/fs:1")
+ .addEntities(field("entity", 2, Enum.INT64))
+ .build(),
+ FeatureRow.newBuilder()
+ .setEventTimestamp(Timestamp.getDefaultInstance())
+ .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("two")))
+ .build());
+
+ List featureRows =
+ ImmutableList.of(
+ FeatureRow.newBuilder()
+ .setFeatureSet("myproject/fs:1")
+ .addFields(field("entity", 1, Enum.INT64))
+ .addFields(field("feature", "one", Enum.STRING))
+ .build(),
+ FeatureRow.newBuilder()
+ .setFeatureSet("myproject/fs:1")
+ .addFields(field("entity", 2, Enum.INT64))
+ .addFields(field("feature", "two", Enum.STRING))
+ .build());
+
+ p.apply(Create.of(featureRows)).apply(redisFeatureSink.write());
+ p.run();
+
+ kvs.forEach(
+ (key, value) -> {
+ byte[] actual = sync.get(key.toByteArray());
+ assertThat(actual, equalTo(value.toByteArray()));
+ });
+ }
+
+ @Test(timeout = 10000)
+ public void shouldRetryFailConnection() throws InterruptedException {
+ RedisConfig redisConfig =
+ RedisConfig.newBuilder()
+ .setHost(REDIS_HOST)
+ .setPort(REDIS_PORT)
+ .setMaxRetries(4)
+ .setInitialBackoffMs(2000)
+ .build();
+ redisFeatureSink = redisFeatureSink.toBuilder().setRedisConfig(redisConfig).build();
+
+ HashMap kvs = new LinkedHashMap<>();
+ kvs.put(
+ RedisKey.newBuilder()
+ .setFeatureSet("myproject/fs:1")
+ .addEntities(field("entity", 1, Enum.INT64))
+ .build(),
+ FeatureRow.newBuilder()
+ .setEventTimestamp(Timestamp.getDefaultInstance())
+ .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("one")))
+ .build());
+
+ List featureRows =
+ ImmutableList.of(
+ FeatureRow.newBuilder()
+ .setFeatureSet("myproject/fs:1")
+ .addFields(field("entity", 1, Enum.INT64))
+ .addFields(field("feature", "one", Enum.STRING))
+ .build());
+
+ PCollection failedElementCount =
+ p.apply(Create.of(featureRows))
+ .apply(redisFeatureSink.write())
+ .getFailedInserts()
+ .apply(Count.globally());
+
+ redis.stop();
+ final ScheduledThreadPoolExecutor redisRestartExecutor = new ScheduledThreadPoolExecutor(1);
+ ScheduledFuture> scheduledRedisRestart =
+ redisRestartExecutor.schedule(
+ () -> {
+ redis.start();
+ },
+ 3,
+ TimeUnit.SECONDS);
+
+ PAssert.that(failedElementCount).containsInAnyOrder(0L);
+ p.run();
+ scheduledRedisRestart.cancel(true);
+
+ kvs.forEach(
+ (key, value) -> {
+ byte[] actual = sync.get(key.toByteArray());
+ assertThat(actual, equalTo(value.toByteArray()));
+ });
+ }
+
+ @Test
+ public void shouldProduceFailedElementIfRetryExceeded() {
+
+ RedisConfig redisConfig =
+ RedisConfig.newBuilder().setHost(REDIS_HOST).setPort(REDIS_PORT + 1).build();
+ redisFeatureSink = redisFeatureSink.toBuilder().setRedisConfig(redisConfig).build();
+
+ HashMap kvs = new LinkedHashMap<>();
+ kvs.put(
+ RedisKey.newBuilder()
+ .setFeatureSet("myproject/fs:1")
+ .addEntities(field("entity", 1, Enum.INT64))
+ .build(),
+ FeatureRow.newBuilder()
+ .setEventTimestamp(Timestamp.getDefaultInstance())
+ .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("one")))
+ .build());
+
+ List featureRows =
+ ImmutableList.of(
+ FeatureRow.newBuilder()
+ .setFeatureSet("myproject/fs:1")
+ .addFields(field("entity", 1, Enum.INT64))
+ .addFields(field("feature", "one", Enum.STRING))
+ .build());
+
+ PCollection failedElementCount =
+ p.apply(Create.of(featureRows))
+ .apply(redisFeatureSink.write())
+ .getFailedInserts()
+ .apply(Count.globally());
+
+ redis.stop();
+ PAssert.that(failedElementCount).containsInAnyOrder(1L);
+ p.run();
+ }
+
+ @Test
+ public void shouldConvertRowWithDuplicateEntitiesToValidKey() {
+
+ FeatureRow offendingRow =
+ FeatureRow.newBuilder()
+ .setFeatureSet("myproject/feature_set:1")
+ .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
+ .addFields(
+ Field.newBuilder()
+ .setName("entity_id_primary")
+ .setValue(Value.newBuilder().setInt32Val(1)))
+ .addFields(
+ Field.newBuilder()
+ .setName("entity_id_primary")
+ .setValue(Value.newBuilder().setInt32Val(2)))
+ .addFields(
+ Field.newBuilder()
+ .setName("entity_id_secondary")
+ .setValue(Value.newBuilder().setStringVal("a")))
+ .addFields(
+ Field.newBuilder()
+ .setName("feature_1")
+ .setValue(Value.newBuilder().setStringVal("strValue1")))
+ .addFields(
+ Field.newBuilder()
+ .setName("feature_2")
+ .setValue(Value.newBuilder().setInt64Val(1001)))
+ .build();
+
+ RedisKey expectedKey =
+ RedisKey.newBuilder()
+ .setFeatureSet("myproject/feature_set:1")
+ .addEntities(
+ Field.newBuilder()
+ .setName("entity_id_primary")
+ .setValue(Value.newBuilder().setInt32Val(1)))
+ .addEntities(
+ Field.newBuilder()
+ .setName("entity_id_secondary")
+ .setValue(Value.newBuilder().setStringVal("a")))
+ .build();
+
+ FeatureRow expectedValue =
+ FeatureRow.newBuilder()
+ .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
+ .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("strValue1")))
+ .addFields(Field.newBuilder().setValue(Value.newBuilder().setInt64Val(1001)))
+ .build();
+
+ p.apply(Create.of(offendingRow)).apply(redisFeatureSink.write());
+
+ p.run();
+
+ byte[] actual = sync.get(expectedKey.toByteArray());
+ assertThat(actual, equalTo(expectedValue.toByteArray()));
+ }
+
+ @Test
+ public void shouldConvertRowWithOutOfOrderFieldsToValidKey() {
+ FeatureRow offendingRow =
+ FeatureRow.newBuilder()
+ .setFeatureSet("myproject/feature_set:1")
+ .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
+ .addFields(
+ Field.newBuilder()
+ .setName("entity_id_secondary")
+ .setValue(Value.newBuilder().setStringVal("a")))
+ .addFields(
+ Field.newBuilder()
+ .setName("entity_id_primary")
+ .setValue(Value.newBuilder().setInt32Val(1)))
+ .addFields(
+ Field.newBuilder()
+ .setName("feature_2")
+ .setValue(Value.newBuilder().setInt64Val(1001)))
+ .addFields(
+ Field.newBuilder()
+ .setName("feature_1")
+ .setValue(Value.newBuilder().setStringVal("strValue1")))
+ .build();
+
+ RedisKey expectedKey =
+ RedisKey.newBuilder()
+ .setFeatureSet("myproject/feature_set:1")
+ .addEntities(
+ Field.newBuilder()
+ .setName("entity_id_primary")
+ .setValue(Value.newBuilder().setInt32Val(1)))
+ .addEntities(
+ Field.newBuilder()
+ .setName("entity_id_secondary")
+ .setValue(Value.newBuilder().setStringVal("a")))
+ .build();
+
+ List expectedFields =
+ Arrays.asList(
+ Field.newBuilder().setValue(Value.newBuilder().setStringVal("strValue1")).build(),
+ Field.newBuilder().setValue(Value.newBuilder().setInt64Val(1001)).build());
+ FeatureRow expectedValue =
+ FeatureRow.newBuilder()
+ .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
+ .addAllFields(expectedFields)
+ .build();
+
+ p.apply(Create.of(offendingRow)).apply(redisFeatureSink.write());
+
+ p.run();
+
+ byte[] actual = sync.get(expectedKey.toByteArray());
+ assertThat(actual, equalTo(expectedValue.toByteArray()));
+ }
+
+ @Test
+ public void shouldMergeDuplicateFeatureFields() {
+ FeatureRow featureRowWithDuplicatedFeatureFields =
+ FeatureRow.newBuilder()
+ .setFeatureSet("myproject/feature_set:1")
+ .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
+ .addFields(
+ Field.newBuilder()
+ .setName("entity_id_primary")
+ .setValue(Value.newBuilder().setInt32Val(1)))
+ .addFields(
+ Field.newBuilder()
+ .setName("entity_id_secondary")
+ .setValue(Value.newBuilder().setStringVal("a")))
+ .addFields(
+ Field.newBuilder()
+ .setName("feature_1")
+ .setValue(Value.newBuilder().setStringVal("strValue1")))
+ .addFields(
+ Field.newBuilder()
+ .setName("feature_1")
+ .setValue(Value.newBuilder().setStringVal("strValue1")))
+ .addFields(
+ Field.newBuilder()
+ .setName("feature_2")
+ .setValue(Value.newBuilder().setInt64Val(1001)))
+ .build();
+
+ RedisKey expectedKey =
+ RedisKey.newBuilder()
+ .setFeatureSet("myproject/feature_set:1")
+ .addEntities(
+ Field.newBuilder()
+ .setName("entity_id_primary")
+ .setValue(Value.newBuilder().setInt32Val(1)))
+ .addEntities(
+ Field.newBuilder()
+ .setName("entity_id_secondary")
+ .setValue(Value.newBuilder().setStringVal("a")))
+ .build();
+
+ FeatureRow expectedValue =
+ FeatureRow.newBuilder()
+ .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
+ .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("strValue1")))
+ .addFields(Field.newBuilder().setValue(Value.newBuilder().setInt64Val(1001)))
+ .build();
+
+ p.apply(Create.of(featureRowWithDuplicatedFeatureFields)).apply(redisFeatureSink.write());
+
+ p.run();
+
+ byte[] actual = sync.get(expectedKey.toByteArray());
+ assertThat(actual, equalTo(expectedValue.toByteArray()));
+ }
+
+ @Test
+ public void shouldPopulateMissingFeatureValuesWithDefaultInstance() {
+ FeatureRow featureRowWithDuplicatedFeatureFields =
+ FeatureRow.newBuilder()
+ .setFeatureSet("myproject/feature_set:1")
+ .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
+ .addFields(
+ Field.newBuilder()
+ .setName("entity_id_primary")
+ .setValue(Value.newBuilder().setInt32Val(1)))
+ .addFields(
+ Field.newBuilder()
+ .setName("entity_id_secondary")
+ .setValue(Value.newBuilder().setStringVal("a")))
+ .addFields(
+ Field.newBuilder()
+ .setName("feature_1")
+ .setValue(Value.newBuilder().setStringVal("strValue1")))
+ .build();
+
+ RedisKey expectedKey =
+ RedisKey.newBuilder()
+ .setFeatureSet("myproject/feature_set:1")
+ .addEntities(
+ Field.newBuilder()
+ .setName("entity_id_primary")
+ .setValue(Value.newBuilder().setInt32Val(1)))
+ .addEntities(
+ Field.newBuilder()
+ .setName("entity_id_secondary")
+ .setValue(Value.newBuilder().setStringVal("a")))
+ .build();
+
+ FeatureRow expectedValue =
+ FeatureRow.newBuilder()
+ .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
+ .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("strValue1")))
+ .addFields(Field.newBuilder().setValue(Value.getDefaultInstance()))
+ .build();
+
+ p.apply(Create.of(featureRowWithDuplicatedFeatureFields)).apply(redisFeatureSink.write());
+
+ p.run();
+
+ byte[] actual = sync.get(expectedKey.toByteArray());
+ assertThat(actual, equalTo(expectedValue.toByteArray()));
+ }
+}