From 8b00ad9743fa7f32cfc5972374d0cd37719bec7d Mon Sep 17 00:00:00 2001 From: Chen Zhiling Date: Thu, 19 Mar 2020 14:49:40 +0800 Subject: [PATCH] Add Redis storage implementation (#547) * Add Redis storage * Remove staleness check; can be checked at the service level * Remove staleness related tests * Add dependencies to top level pom * Clean up code --- .../storage/common/retry/BackOffExecutor.java | 58 +++ .../feast/storage/common/retry/Retriable.java | 25 + storage/connectors/redis/pom.xml | 58 +++ .../redis/retrieval/FeatureRowDecoder.java | 95 ++++ .../redis/retrieval/RedisOnlineRetriever.java | 198 ++++++++ .../connectors/redis/write/RedisCustomIO.java | 292 +++++++++++ .../redis/write/RedisFeatureSink.java | 74 +++ .../redis/write/RedisIngestionClient.java | 49 ++ .../write/RedisStandaloneIngestionClient.java | 125 +++++ .../retrieval/RedisOnlineRetrieverTest.java | 275 ++++++++++ .../connectors/redis/test/TestUtil.java | 44 ++ .../redis/write/RedisFeatureSinkTest.java | 479 ++++++++++++++++++ 12 files changed, 1772 insertions(+) create mode 100644 storage/api/src/main/java/feast/storage/common/retry/BackOffExecutor.java create mode 100644 storage/api/src/main/java/feast/storage/common/retry/Retriable.java create mode 100644 storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retrieval/FeatureRowDecoder.java create mode 100644 storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retrieval/RedisOnlineRetriever.java create mode 100644 storage/connectors/redis/src/main/java/feast/storage/connectors/redis/write/RedisCustomIO.java create mode 100644 storage/connectors/redis/src/main/java/feast/storage/connectors/redis/write/RedisFeatureSink.java create mode 100644 storage/connectors/redis/src/main/java/feast/storage/connectors/redis/write/RedisIngestionClient.java create mode 100644 storage/connectors/redis/src/main/java/feast/storage/connectors/redis/write/RedisStandaloneIngestionClient.java create mode 100644 
storage/connectors/redis/src/test/java/feast/storage/connectors/redis/retrieval/RedisOnlineRetrieverTest.java create mode 100644 storage/connectors/redis/src/test/java/feast/storage/connectors/redis/test/TestUtil.java create mode 100644 storage/connectors/redis/src/test/java/feast/storage/connectors/redis/write/RedisFeatureSinkTest.java diff --git a/storage/api/src/main/java/feast/storage/common/retry/BackOffExecutor.java b/storage/api/src/main/java/feast/storage/common/retry/BackOffExecutor.java new file mode 100644 index 0000000000..296582f8b3 --- /dev/null +++ b/storage/api/src/main/java/feast/storage/common/retry/BackOffExecutor.java @@ -0,0 +1,58 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package feast.storage.common.retry;
+
+import java.io.Serializable;
+import org.apache.beam.sdk.util.BackOff;
+import org.apache.beam.sdk.util.BackOffUtils;
+import org.apache.beam.sdk.util.FluentBackoff;
+import org.apache.beam.sdk.util.Sleeper;
+import org.joda.time.Duration;
+
+/**
+ * Runs a {@link Retriable} and re-runs it under a Beam {@link FluentBackoff} policy for as long as
+ * the retriable classifies its failure as retriable and the back-off grants another attempt.
+ */
+public class BackOffExecutor implements Serializable {
+
+  private final Integer maxRetries;
+  private final Duration initialBackOff;
+
+  /**
+   * @param maxRetries value handed to {@link FluentBackoff#withMaxRetries(int)}
+   * @param initialBackOff value handed to {@link FluentBackoff#withInitialBackoff(Duration)}
+   */
+  public BackOffExecutor(Integer maxRetries, Duration initialBackOff) {
+    this.maxRetries = maxRetries;
+    this.initialBackOff = initialBackOff;
+  }
+
+  /**
+   * Executes the retriable with a back-off policy built from this executor's settings.
+   *
+   * @param retriable operation to run
+   * @throws Exception the last failure, when it is not retriable or the back-off is exhausted
+   */
+  public void execute(Retriable retriable) throws Exception {
+    execute(
+        retriable,
+        FluentBackoff.DEFAULT.withMaxRetries(maxRetries).withInitialBackoff(initialBackOff));
+  }
+
+  private void execute(Retriable retriable, FluentBackoff backoff) throws Exception {
+    final BackOff attempts = backoff.backoff();
+    final Sleeper sleeper = Sleeper.DEFAULT;
+    for (; ; ) {
+      try {
+        retriable.execute();
+        return;
+      } catch (Exception failure) {
+        // BackOffUtils.next sleeps for the next interval and reports whether one was available;
+        // it is only consulted when the retriable accepts the failure as retriable.
+        boolean tryAgain =
+            retriable.isExceptionRetriable(failure) && BackOffUtils.next(sleeper, attempts);
+        if (!tryAgain) {
+          throw failure;
+        }
+        retriable.cleanUpAfterFailure();
+      }
+    }
+  }
+}
diff --git a/storage/api/src/main/java/feast/storage/common/retry/Retriable.java b/storage/api/src/main/java/feast/storage/common/retry/Retriable.java
new file mode 100644
index 0000000000..2c92c85175
--- /dev/null
+++ b/storage/api/src/main/java/feast/storage/common/retry/Retriable.java
@@ -0,0 +1,25 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright 2018-2020 The Feast Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package feast.storage.common.retry;
+
+/**
+ * An operation that {@code BackOffExecutor} can run and, on failure, run again.
+ *
+ * <p>{@code BackOffExecutor} calls {@link #execute()}; if it throws, the exception is passed to
+ * {@link #isExceptionRetriable(Exception)}, and when a retry is going to happen {@link
+ * #cleanUpAfterFailure()} is invoked before the next attempt.
+ */
+public interface Retriable {
+  /**
+   * Performs one attempt of the operation.
+   *
+   * @throws Exception if the attempt fails
+   */
+  void execute() throws Exception;
+
+  /**
+   * Decides whether a failure thrown by {@link #execute()} is worth another attempt.
+   *
+   * @param e the exception thrown by the failed attempt
+   * @return true if the operation may be attempted again
+   */
+  Boolean isExceptionRetriable(Exception e);
+
+  /** Hook invoked after a failed attempt, just before the operation is attempted again. */
+  void cleanUpAfterFailure();
+}
diff --git a/storage/connectors/redis/pom.xml b/storage/connectors/redis/pom.xml
index 892838efc9..6c50895bd2 100644
--- a/storage/connectors/redis/pom.xml
+++ b/storage/connectors/redis/pom.xml
@@ -13,6 +13,64 @@ Feast Storage Connector for Redis + + io.lettuce + lettuce-core + + + + org.apache.commons + commons-lang3 + 3.9 + + + + com.google.auto.value + auto-value-annotations + 1.6.6 + + + + com.google.auto.value + auto-value + 1.6.6 + provided + + + + org.mockito + mockito-core + 2.23.0 + test + + + + + com.github.kstyrc + embedded-redis + test + + + + org.apache.beam + beam-runners-direct-java + ${org.apache.beam.version} + test + + + + org.hamcrest + hamcrest-core + test + + + + org.hamcrest + hamcrest-library + test + + + junit junit
diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retrieval/FeatureRowDecoder.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retrieval/FeatureRowDecoder.java
new file mode 100644
index 0000000000..7afea1b972
--- /dev/null
+++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retrieval/FeatureRowDecoder.java
@@ -0,0 +1,95 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright 2018-2020 The Feast Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the 
License.
+ * You may obtain a copy of the License at
+ *
+ *     https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package feast.storage.connectors.redis.retrieval;
+
+import feast.core.FeatureSetProto.FeatureSetSpec;
+import feast.core.FeatureSetProto.FeatureSpec;
+import feast.types.FeatureRowProto.FeatureRow;
+import feast.types.FieldProto.Field;
+import java.util.Comparator;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+/**
+ * Restores the feature set reference and field names of a {@link FeatureRow} that was stored
+ * without them, using the feature names of a {@link FeatureSetSpec}.
+ */
+public class FeatureRowDecoder {
+
+  private final String featureSetRef;
+  private final FeatureSetSpec spec;
+
+  /**
+   * @param featureSetRef feature set reference written into every decoded row
+   * @param spec feature set spec whose feature names are applied to the decoded fields
+   */
+  public FeatureRowDecoder(String featureSetRef, FeatureSetSpec spec) {
+    this.featureSetRef = featureSetRef;
+    this.spec = spec;
+  }
+
+  /**
+   * Tells whether a row is in the encoded form, i.e. carries neither a feature set reference nor
+   * any field name. Needed for backward compatibility, so that Feast serving can keep serving non
+   * encoded rows ingested by an older version of Feast.
+   *
+   * @param featureRow Feature row
+   * @return true if the feature set and all field names are empty
+   */
+  public Boolean isEncoded(FeatureRow featureRow) {
+    if (!featureRow.getFeatureSet().isEmpty()) {
+      return false;
+    }
+    return featureRow.getFieldsList().stream().map(Field::getName).allMatch(String::isEmpty);
+  }
+
+  /**
+   * Checks that an encoded row can be decoded without exception: it must hold exactly as many
+   * fields as the spec has features.
+   *
+   * @param featureRow Feature row
+   * @return true if the field count equals the spec's feature count
+   */
+  public Boolean isEncodingValid(FeatureRow featureRow) {
+    int encodedFieldCount = featureRow.getFieldsList().size();
+    int specFeatureCount = spec.getFeaturesList().size();
+    return encodedFieldCount == specFeatureCount;
+  }
+
+  /**
+   * Decodes a row by assigning the spec's feature names, in sorted order, to the row's fields by
+   * position, and setting the feature set reference.
+   *
+   * @param encodedFeatureRow Feature row
+   * @return a copy of the row with the feature set reference and field names populated
+   */
+  public FeatureRow decode(FeatureRow encodedFeatureRow) {
+    final List<Field> namelessFields = encodedFeatureRow.getFieldsList();
+
+    final List<String> sortedNames =
+        spec.getFeaturesList().stream()
+            .map(FeatureSpec::getName)
+            .sorted(Comparator.naturalOrder())
+            .collect(Collectors.toList());
+
+    final FeatureRow.Builder decoded =
+        encodedFeatureRow.toBuilder().clearFields().setFeatureSet(featureSetRef);
+    // Position i of the stored fields corresponds to the i-th feature name in sorted order.
+    IntStream.range(0, sortedNames.size())
+        .forEach(
+            position ->
+                decoded.addFields(
+                    namelessFields
+                        .get(position)
+                        .toBuilder()
+                        .setName(sortedNames.get(position))
+                        .build()));
+    return decoded.build();
+  }
+}
diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retrieval/RedisOnlineRetriever.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retrieval/RedisOnlineRetriever.java
new file mode 100644
index 0000000000..7ff925a722
--- /dev/null
+++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retrieval/RedisOnlineRetriever.java
@@ -0,0 +1,198 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright 2018-2020 The Feast Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package feast.storage.connectors.redis.retrieval;
+
+import com.google.protobuf.AbstractMessageLite;
+import com.google.protobuf.InvalidProtocolBufferException;
+import feast.core.FeatureSetProto.EntitySpec;
+import feast.core.FeatureSetProto.FeatureSetSpec;
+import feast.serving.ServingAPIProto.FeatureReference;
+import feast.serving.ServingAPIProto.GetOnlineFeaturesRequest.EntityRow;
+import feast.storage.RedisProto.RedisKey;
+import feast.storage.api.retrieval.FeatureSetRequest;
+import feast.storage.api.retrieval.OnlineRetriever;
+import feast.types.FeatureRowProto.FeatureRow;
+import feast.types.FieldProto.Field;
+import feast.types.ValueProto.Value;
+import io.grpc.Status;
+import io.lettuce.core.api.StatefulRedisConnection;
+import io.lettuce.core.api.sync.RedisCommands;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ExecutionException;
+import java.util.stream.Collectors;
+
+/** {@link OnlineRetriever} that reads feature rows from Redis with a single MGET per feature set. */
+public class RedisOnlineRetriever implements OnlineRetriever {
+
+  private final RedisCommands<byte[], byte[]> syncCommands;
+
+  public RedisOnlineRetriever(StatefulRedisConnection<byte[], byte[]> connection) {
+    this.syncCommands = connection.sync();
+  }
+
+  /**
+   * Gets online features from redis. This method returns a list of {@link FeatureRow}s
+   * corresponding to each feature set spec. Each feature row in the list then corresponds to an
+   * {@link EntityRow} provided by the user.
+   *
+   * @param entityRows list of entity rows in the feature request
+   * @param featureSetRequests list of requests, each pairing a {@link
+   *     feast.core.FeatureSetProto.FeatureSetSpec} with the feature references in the request tied
+   *     to that feature set.
+   * @return List of List of {@link FeatureRow}
+   */
+  @Override
+  public List<List<FeatureRow>> getOnlineFeatures(
+      List<EntityRow> entityRows, List<FeatureSetRequest> featureSetRequests) {
+
+    List<List<FeatureRow>> featureRows = new ArrayList<>();
+    for (FeatureSetRequest featureSetRequest : featureSetRequests) {
+      List<RedisKey> redisKeys = buildRedisKeys(entityRows, featureSetRequest.getSpec());
+      try {
+        List<FeatureRow> featureRowsForFeatureSet =
+            sendAndProcessMultiGet(
+                redisKeys,
+                featureSetRequest.getSpec(),
+                featureSetRequest.getFeatureReferences().asList());
+        featureRows.add(featureRowsForFeatureSet);
+      } catch (InvalidProtocolBufferException | ExecutionException e) {
+        throw Status.INTERNAL
+            .withDescription("Unable to parse protobuf while retrieving feature")
+            .withCause(e)
+            .asRuntimeException();
+      }
+    }
+    return featureRows;
+  }
+
+  /**
+   * Builds one {@link RedisKey} per entity row for the given feature set.
+   *
+   * <p>The entity names are sorted here, once, so that every key lists its entities in the same
+   * order without re-sorting (and mutating) a shared list for each row.
+   */
+  private List<RedisKey> buildRedisKeys(List<EntityRow> entityRows, FeatureSetSpec featureSetSpec) {
+    String featureSetRef = generateFeatureSetStringRef(featureSetSpec);
+    List<String> sortedEntityNames =
+        featureSetSpec.getEntitiesList().stream()
+            .map(EntitySpec::getName)
+            .sorted()
+            .collect(Collectors.toList());
+    return entityRows.stream()
+        .map(row -> makeRedisKey(featureSetRef, sortedEntityNames, row))
+        .collect(Collectors.toList());
+  }
+
+  /**
+   * Create {@link RedisKey}
+   *
+   * @param featureSet featureSet reference of the feature. E.g. feature_set_1:1
+   * @param featureSetEntityNames entity names that belong to the featureSet, already sorted; the
+   *     entities are added to the key in this order
+   * @param entityRow entityRow to build the key from
+   * @return {@link RedisKey}
+   */
+  private RedisKey makeRedisKey(
+      String featureSet, List<String> featureSetEntityNames, EntityRow entityRow) {
+    RedisKey.Builder builder = RedisKey.newBuilder().setFeatureSet(featureSet);
+    Map<String, Value> fieldsMap = entityRow.getFieldsMap();
+    for (String entityName : featureSetEntityNames) {
+      if (!fieldsMap.containsKey(entityName)) {
+        throw Status.INVALID_ARGUMENT
+            .withDescription(
+                String.format(
+                    "Entity row fields \"%s\" does not contain required entity field \"%s\"",
+                    fieldsMap.keySet().toString(), entityName))
+            .asRuntimeException();
+      }
+
+      builder.addEntities(
+          Field.newBuilder().setName(entityName).setValue(fieldsMap.get(entityName)));
+    }
+    return builder.build();
+  }
+
+  /**
+   * Fetches the stored rows for the given keys and converts them into feature rows, one per key
+   * and in key order. A key with no stored value, or whose stored row is encoded but does not match
+   * the spec, yields a placeholder row that only names the requested features.
+   */
+  private List<FeatureRow> sendAndProcessMultiGet(
+      List<RedisKey> redisKeys,
+      FeatureSetSpec featureSetSpec,
+      List<FeatureReference> featureReferences)
+      throws InvalidProtocolBufferException, ExecutionException {
+
+    List<byte[]> values = sendMultiGet(redisKeys);
+    List<FeatureRow> featureRows = new ArrayList<>();
+
+    // Every key of one request carries the same feature set reference, so the placeholder row and
+    // the decoder are built once rather than once per value.
+    String featureSetRef = generateFeatureSetStringRef(featureSetSpec);
+    FeatureRow.Builder nullFeatureRowBuilder = FeatureRow.newBuilder().setFeatureSet(featureSetRef);
+    for (FeatureReference featureReference : featureReferences) {
+      nullFeatureRowBuilder.addFields(Field.newBuilder().setName(featureReference.getName()));
+    }
+    FeatureRow nullFeatureRow = nullFeatureRowBuilder.build();
+    FeatureRowDecoder decoder = new FeatureRowDecoder(featureSetRef, featureSetSpec);
+
+    for (byte[] value : values) {
+      if (value == null) {
+        featureRows.add(nullFeatureRow);
+        continue;
+      }
+
+      FeatureRow featureRow = FeatureRow.parseFrom(value);
+      if (decoder.isEncoded(featureRow)) {
+        if (!decoder.isEncodingValid(featureRow)) {
+          featureRows.add(nullFeatureRow);
+          continue;
+        }
+        featureRow = decoder.decode(featureRow);
+      }
+
+      featureRows.add(featureRow);
+    }
+    return featureRows;
+  }
+
+  /**
+   * Send a list of get request as an mget
+   *
+   * @param keys list of {@link RedisKey}
+   * @return list of {@link FeatureRow} in primitive byte representation for each {@link RedisKey};
+   *     an element is null when Redis holds no value for the key
+   */
+  private List<byte[]> sendMultiGet(List<RedisKey> keys) {
+    if (keys.isEmpty()) {
+      // MGET needs at least one key; with no entity rows there is simply nothing to fetch.
+      return new ArrayList<>();
+    }
+    try {
+      byte[][] binaryKeys =
+          keys.stream().map(AbstractMessageLite::toByteArray).toArray(byte[][]::new);
+      return syncCommands.mget(binaryKeys).stream()
+          .map(keyValue -> keyValue.getValueOrElse(null))
+          .collect(Collectors.toList());
+    } catch (Exception e) {
+      throw Status.NOT_FOUND
+          .withDescription("Unable to retrieve feature from Redis")
+          .withCause(e)
+          .asRuntimeException();
+    }
+  }
+
+  // TODO: Refactor this out to common package?
+  private static String generateFeatureSetStringRef(FeatureSetSpec featureSetSpec) {
+    String ref = String.format("%s/%s", featureSetSpec.getProject(), featureSetSpec.getName());
+    if (featureSetSpec.getVersion() > 0) {
+      return ref + String.format(":%d", featureSetSpec.getVersion());
+    }
+    return ref;
+  }
+}
diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/write/RedisCustomIO.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/write/RedisCustomIO.java
new file mode 100644
index 0000000000..e53861b99d
--- /dev/null
+++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/write/RedisCustomIO.java
@@ -0,0 +1,292 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright 2018-2019 The Feast Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package feast.storage.connectors.redis.write;
+
+import feast.core.FeatureSetProto.EntitySpec;
+import feast.core.FeatureSetProto.FeatureSetSpec;
+import feast.core.FeatureSetProto.FeatureSpec;
+import feast.core.StoreProto.Store.RedisConfig;
+import feast.storage.RedisProto.RedisKey;
+import feast.storage.RedisProto.RedisKey.Builder;
+import feast.storage.api.write.FailedElement;
+import feast.storage.api.write.WriteResult;
+import feast.storage.common.retry.Retriable;
+import feast.types.FeatureRowProto.FeatureRow;
+import feast.types.FieldProto.Field;
+import feast.types.ValueProto;
+import io.lettuce.core.RedisConnectionException;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ExecutionException;
+import java.util.stream.Collectors;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.TupleTagList;
+import org.apache.commons.lang3.exception.ExceptionUtils;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** Beam transform that writes {@link FeatureRow}s to Redis in batches. */
+public class RedisCustomIO {
+
+  private static final int DEFAULT_BATCH_SIZE = 1000;
+  private static final int DEFAULT_TIMEOUT = 2000;
+
+  private static final TupleTag<FeatureRow> successfulInsertsTag =
+      new TupleTag<>("successfulInserts") {};
+  private static final TupleTag<FailedElement> failedInsertsTupleTag =
+      new TupleTag<>("failedInserts") {};
+
+  private static final Logger log = LoggerFactory.getLogger(RedisCustomIO.class);
+
+  private RedisCustomIO() {}
+
+  public static Write write(RedisConfig redisConfig, Map<String, FeatureSetSpec> featureSetSpecs) {
+    return new Write(redisConfig, featureSetSpecs);
+  }
+
+  /** ServingStoreWrite data to a Redis server. */
+  public static class Write extends PTransform<PCollection<FeatureRow>, WriteResult> {
+
+    private Map<String, FeatureSetSpec> featureSetSpecs;
+    private RedisConfig redisConfig;
+    // Zero means "not configured"; WriteDoFn keeps its default for non-positive values.
+    private int batchSize;
+    private int timeout;
+
+    public Write(RedisConfig redisConfig, Map<String, FeatureSetSpec> featureSetSpecs) {
+
+      this.redisConfig = redisConfig;
+      this.featureSetSpecs = featureSetSpecs;
+    }
+
+    /** Sets the number of rows buffered before a batch is written; non-positive keeps the default. */
+    public Write withBatchSize(int batchSize) {
+      this.batchSize = batchSize;
+      return this;
+    }
+
+    /** Sets the timeout handed to the write function; non-positive keeps the default. */
+    public Write withTimeout(int timeout) {
+      this.timeout = timeout;
+      return this;
+    }
+
+    @Override
+    public WriteResult expand(PCollection<FeatureRow> input) {
+      // Forward the configured batch size and timeout; previously they were stored on this
+      // transform but never reached the DoFn, so withBatchSize/withTimeout had no effect.
+      WriteDoFn writeDoFn =
+          new WriteDoFn(redisConfig, featureSetSpecs)
+              .withBatchSize(batchSize)
+              .withTimeout(timeout);
+      PCollectionTuple redisWrite =
+          input.apply(
+              ParDo.of(writeDoFn)
+                  .withOutputTags(successfulInsertsTag, TupleTagList.of(failedInsertsTupleTag)));
+      return WriteResult.in(
+          input.getPipeline(),
+          redisWrite.get(successfulInsertsTag),
+          redisWrite.get(failedInsertsTupleTag));
+    }
+
+    /**
+     * Buffers incoming rows and writes them to Redis once {@code batchSize} rows have accumulated
+     * or the bundle finishes. Rows of a successful batch go to the success output; when a batch
+     * fails, every row of that batch is emitted as a {@link FailedElement}.
+     */
+    public static class WriteDoFn extends DoFn<FeatureRow, FeatureRow> {
+
+      private final List<FeatureRow> featureRows = new ArrayList<>();
+      private Map<String, FeatureSetSpec> featureSetSpecs;
+      private int batchSize = DEFAULT_BATCH_SIZE;
+      // Stored for configuration; not read anywhere else in this class.
+      private int timeout = DEFAULT_TIMEOUT;
+      private RedisIngestionClient redisIngestionClient;
+
+      WriteDoFn(RedisConfig config, Map<String, FeatureSetSpec> featureSetSpecs) {
+
+        this.redisIngestionClient = new RedisStandaloneIngestionClient(config);
+        this.featureSetSpecs = featureSetSpecs;
+      }
+
+      public WriteDoFn withBatchSize(int batchSize) {
+        if (batchSize > 0) {
+          this.batchSize = batchSize;
+        }
+        return this;
+      }
+
+      public WriteDoFn withTimeout(int timeout) {
+        if (timeout > 0) {
+          this.timeout = timeout;
+        }
+        return this;
+      }
+
+      @Setup
+      public void setup() {
+        this.redisIngestionClient.setup();
+      }
+
+      @StartBundle
+      public void startBundle() {
+        try {
+          redisIngestionClient.connect();
+        } catch (RedisConnectionException e) {
+          // Deliberately not rethrown: executeBatch connects again, under the back-off policy.
+          log.error("Connection to redis cannot be established ", e);
+        }
+        featureRows.clear();
+      }
+
+      /** Writes all buffered rows, retrying on {@link RedisConnectionException}. */
+      private void executeBatch() throws Exception {
+        this.redisIngestionClient
+            .getBackOffExecutor()
+            .execute(
+                new Retriable() {
+                  @Override
+                  public void execute() throws ExecutionException, InterruptedException {
+                    if (!redisIngestionClient.isConnected()) {
+                      redisIngestionClient.connect();
+                    }
+                    featureRows.forEach(
+                        row -> {
+                          redisIngestionClient.set(getKey(row), getValue(row));
+                        });
+                    redisIngestionClient.sync();
+                  }
+
+                  @Override
+                  public Boolean isExceptionRetriable(Exception e) {
+                    return e instanceof RedisConnectionException;
+                  }
+
+                  @Override
+                  public void cleanUpAfterFailure() {}
+                });
+      }
+
+      private FailedElement toFailedElement(
+          FeatureRow featureRow, Exception exception, String jobName) {
+        return FailedElement.newBuilder()
+            .setJobName(jobName)
+            .setTransformName("RedisCustomIO")
+            .setPayload(featureRow.toString())
+            .setErrorMessage(exception.getMessage())
+            .setStackTrace(ExceptionUtils.getStackTrace(exception))
+            .build();
+      }
+
+      /**
+       * Looks up the spec of the row's feature set.
+       *
+       * @throws IllegalArgumentException if no spec is known for the row's feature set
+       */
+      private FeatureSetSpec getSpec(FeatureRow featureRow) {
+        FeatureSetSpec featureSetSpec = featureSetSpecs.get(featureRow.getFeatureSet());
+        if (featureSetSpec == null) {
+          throw new IllegalArgumentException(
+              String.format(
+                  "No feature set spec available for feature set \"%s\"",
+                  featureRow.getFeatureSet()));
+        }
+        return featureSetSpec;
+      }
+
+      /**
+       * Builds the serialized {@link RedisKey} of a row: its feature set reference plus its entity
+       * fields in sorted entity-name order. For a repeated entity field the first occurrence wins.
+       *
+       * @throws IllegalArgumentException if the feature set is unknown or an entity field is missing
+       */
+      private byte[] getKey(FeatureRow featureRow) {
+        List<String> entityNames =
+            getSpec(featureRow).getEntitiesList().stream()
+                .map(EntitySpec::getName)
+                .sorted()
+                .collect(Collectors.toList());
+
+        Map<String, Field> entityFields = new HashMap<>();
+        Builder redisKeyBuilder = RedisKey.newBuilder().setFeatureSet(featureRow.getFeatureSet());
+        for (Field field : featureRow.getFieldsList()) {
+          if (entityNames.contains(field.getName())) {
+            entityFields.putIfAbsent(
+                field.getName(),
+                Field.newBuilder().setName(field.getName()).setValue(field.getValue()).build());
+          }
+        }
+        for (String entityName : entityNames) {
+          Field entityField = entityFields.get(entityName);
+          if (entityField == null) {
+            // Fail with a descriptive message instead of a bare NullPointerException from the
+            // protobuf builder.
+            throw new IllegalArgumentException(
+                String.format(
+                    "Feature row of feature set \"%s\" is missing entity field \"%s\"",
+                    featureRow.getFeatureSet(), entityName));
+          }
+          redisKeyBuilder.addEntities(entityField);
+        }
+        return redisKeyBuilder.build().toByteArray();
+      }
+
+      /**
+       * Builds the serialized value of a row: the event timestamp plus one nameless field per
+       * feature of the spec, in sorted feature-name order. Features absent from the row get a
+       * default {@link ValueProto.Value}.
+       *
+       * @throws IllegalArgumentException if the feature set is unknown
+       */
+      private byte[] getValue(FeatureRow featureRow) {
+        List<String> featureNames =
+            getSpec(featureRow).getFeaturesList().stream()
+                .map(FeatureSpec::getName)
+                .collect(Collectors.toList());
+        Map<String, Field> fieldValueOnlyMap =
+            featureRow.getFieldsList().stream()
+                .filter(field -> featureNames.contains(field.getName()))
+                .distinct()
+                .collect(
+                    Collectors.toMap(
+                        Field::getName,
+                        field -> Field.newBuilder().setValue(field.getValue()).build(),
+                        // A feature repeated with different values keeps its first occurrence,
+                        // matching getKey, instead of failing with IllegalStateException.
+                        (first, second) -> first));
+
+        List<Field> values =
+            featureNames.stream()
+                .sorted()
+                .map(
+                    featureName ->
+                        fieldValueOnlyMap.getOrDefault(
+                            featureName,
+                            Field.newBuilder()
+                                .setValue(ValueProto.Value.getDefaultInstance())
+                                .build()))
+                .collect(Collectors.toList());
+
+        return FeatureRow.newBuilder()
+            .setEventTimestamp(featureRow.getEventTimestamp())
+            .addAllFields(values)
+            .build()
+            .toByteArray();
+      }
+
+      @ProcessElement
+      public void processElement(ProcessContext context) {
+        FeatureRow featureRow = context.element();
+        featureRows.add(featureRow);
+        if (featureRows.size() >= batchSize) {
+          try {
+            executeBatch();
+            featureRows.forEach(row -> context.output(successfulInsertsTag, row));
+            featureRows.clear();
+          } catch (Exception e) {
+            featureRows.forEach(
+                failedMutation -> {
+                  FailedElement failedElement =
+                      toFailedElement(failedMutation, e, context.getPipelineOptions().getJobName());
+                  context.output(failedInsertsTupleTag, failedElement);
+                });
+            featureRows.clear();
+          }
+        }
+      }
+
+      @FinishBundle
+      public void finishBundle(FinishBundleContext context)
+          throws IOException, InterruptedException {
+        if (featureRows.size() > 0) {
+          try {
+            executeBatch();
+            featureRows.forEach(
+                row ->
+                    context.output(
+                        successfulInsertsTag, row, Instant.now(), GlobalWindow.INSTANCE));
+            featureRows.clear();
+          } catch (Exception e) {
+            featureRows.forEach(
+                failedMutation -> {
+                  FailedElement failedElement =
+                      toFailedElement(failedMutation, e, context.getPipelineOptions().getJobName());
+                  context.output(
+                      failedInsertsTupleTag, failedElement, Instant.now(), GlobalWindow.INSTANCE);
+                });
+            featureRows.clear();
+          }
+        }
+      }
+
+      @Teardown
+      public void teardown() {
+        redisIngestionClient.shutdown();
+      }
+    }
+  }
+}
diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/write/RedisFeatureSink.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/write/RedisFeatureSink.java
new file mode 100644
index 0000000000..2f566bb78c
--- /dev/null
+++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/write/RedisFeatureSink.java
@@ -0,0 +1,74 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright 2018-2020 The Feast Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package feast.storage.connectors.redis.write;
+
+import com.google.auto.value.AutoValue;
+import feast.core.FeatureSetProto.FeatureSet;
+import feast.core.FeatureSetProto.FeatureSetSpec;
+import feast.core.StoreProto.Store.RedisConfig;
+import feast.storage.api.write.FeatureSink;
+import feast.storage.api.write.WriteResult;
+import feast.types.FeatureRowProto.FeatureRow;
+import io.lettuce.core.RedisClient;
+import io.lettuce.core.RedisConnectionException;
+import io.lettuce.core.RedisURI;
+import java.util.Map;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.values.PCollection;
+
+/** {@link FeatureSink} that writes feature rows to Redis through {@link RedisCustomIO}. */
+@AutoValue
+public abstract class RedisFeatureSink implements FeatureSink {
+
+  public abstract RedisConfig getRedisConfig();
+
+  public abstract Map<String, FeatureSetSpec> getFeatureSetSpecs();
+
+  public abstract Builder toBuilder();
+
+  public static Builder builder() {
+    return new AutoValue_RedisFeatureSink.Builder();
+  }
+
+  @AutoValue.Builder
+  public abstract static class Builder {
+    public abstract Builder setRedisConfig(RedisConfig redisConfig);
+
+    public abstract Builder setFeatureSetSpecs(Map<String, FeatureSetSpec> featureSetSpecs);
+
+    public abstract RedisFeatureSink build();
+  }
+
+  /**
+   * Verifies that the configured Redis instance accepts connections. Nothing is written and the
+   * {@code featureSet} argument is not used.
+   *
+   * @param featureSet feature set about to be written (unused)
+   * @throws RuntimeException if the connection cannot be established; the underlying {@link
+   *     RedisConnectionException} is attached as the cause
+   */
+  @Override
+  public void prepareWrite(FeatureSet featureSet) {
+    RedisClient redisClient =
+        RedisClient.create(RedisURI.create(getRedisConfig().getHost(), getRedisConfig().getPort()));
+    try {
+      redisClient.connect();
+    } catch (RedisConnectionException e) {
+      throw new RuntimeException(
+          String.format(
+              "Failed to connect to Redis at host: '%s' port: '%d'. Please check that your Redis is running and accessible from Feast.",
+              getRedisConfig().getHost(), getRedisConfig().getPort()),
+          e);
+    } finally {
+      // Release the client's resources on the failure path too, not only after a successful probe.
+      redisClient.shutdown();
+    }
+  }
+
+  @Override
+  public PTransform<PCollection<FeatureRow>, WriteResult> write() {
+    return new RedisCustomIO.Write(getRedisConfig(), getFeatureSetSpecs());
+  }
+}
diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/write/RedisIngestionClient.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/write/RedisIngestionClient.java
new file mode 100644
index 0000000000..7004b94282
--- /dev/null
+++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/write/RedisIngestionClient.java
@@ -0,0 +1,49 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright 2018-2020 The Feast Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package feast.storage.connectors.redis.write;
+
+import feast.storage.common.retry.BackOffExecutor;
+import java.io.Serializable;
+
+/**
+ * Serializable client used by the Redis write transform to send write commands to Redis.
+ *
+ * <p>NOTE(review): judging by {@code RedisStandaloneIngestionClient}, write commands are issued
+ * asynchronously and {@link #sync()} waits for them; confirm for any other implementation.
+ */
+public interface RedisIngestionClient extends Serializable {
+
+  /** Prepares the client for use; called before any connection is made. */
+  void setup();
+
+  /** Returns the executor that callers use to retry failed writes with back-off. */
+  BackOffExecutor getBackOffExecutor();
+
+  /** Releases the resources held by the client. */
+  void shutdown();
+
+  /** Opens the connection to Redis. */
+  void connect();
+
+  /** Returns whether a connection has been opened via {@link #connect()}. */
+  boolean isConnected();
+
+  /** Waits for the commands issued so far to complete. */
+  void sync();
+
+  /** Issues a Redis PEXPIRE for {@code key} with the given time to live in milliseconds. */
+  void pexpire(byte[] key, Long expiryMillis);
+
+  /** Issues a Redis APPEND of {@code value} to {@code key}. */
+  void append(byte[] key, byte[] value);
+
+  /** Issues a Redis SET of {@code key} to {@code value}. */
+  void set(byte[] key, byte[] value);
+
+  /** Issues a Redis LPUSH of {@code value} onto the list at {@code key}. */
+  void lpush(byte[] key, byte[] value);
+
+  /** Issues a Redis RPUSH of {@code value} onto the list at {@code key}. */
+  void rpush(byte[] key, byte[] value);
+
+  /** Issues a Redis SADD of {@code value} to the set at {@code key}. */
+  void sadd(byte[] key, byte[] value);
+
+  /** Issues a Redis ZADD of {@code value} with {@code score} to the sorted set at {@code key}. */
+  void zadd(byte[] key, Long score, byte[] value);
+}
diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/write/RedisStandaloneIngestionClient.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/write/RedisStandaloneIngestionClient.java
new file mode 100644
index 0000000000..583eade2aa
--- /dev/null
+++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/write/RedisStandaloneIngestionClient.java
@@ -0,0 +1,125 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright 2018-2020 The Feast Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package feast.storage.connectors.redis.write;
+
+import com.google.common.collect.Lists;
+import feast.core.StoreProto;
+import feast.storage.common.retry.BackOffExecutor;
+import io.lettuce.core.LettuceFutures;
+import io.lettuce.core.RedisClient;
+import io.lettuce.core.RedisCommandTimeoutException;
+import io.lettuce.core.RedisFuture;
+import io.lettuce.core.RedisURI;
+import io.lettuce.core.api.StatefulRedisConnection;
+import io.lettuce.core.api.async.RedisAsyncCommands;
+import io.lettuce.core.codec.ByteArrayCodec;
+import java.util.List;
+import java.util.concurrent.TimeUnit;
+import org.joda.time.Duration;
+
+/**
+ * {@link RedisIngestionClient} for a single Redis instance. Commands are issued asynchronously and
+ * their futures are collected until {@link #sync()} waits for them.
+ */
+public class RedisStandaloneIngestionClient implements RedisIngestionClient {
+  private static final int DEFAULT_TIMEOUT = 2000;
+  // TODO: should this be configurable?
+  private static final long SYNC_TIMEOUT_SECONDS = 60;
+
+  private final String host;
+  private final int port;
+  private final BackOffExecutor backOffExecutor;
+  private RedisClient redisclient;
+  private StatefulRedisConnection<byte[], byte[]> connection;
+  private RedisAsyncCommands<byte[], byte[]> commands;
+  private List<RedisFuture<?>> futures = Lists.newArrayList();
+
+  public RedisStandaloneIngestionClient(StoreProto.Store.RedisConfig redisConfig) {
+    this.host = redisConfig.getHost();
+    this.port = redisConfig.getPort();
+    // A non-positive configured back-off falls back to 1 ms.
+    long backoffMs = redisConfig.getInitialBackoffMs() > 0 ? redisConfig.getInitialBackoffMs() : 1;
+    this.backOffExecutor =
+        new BackOffExecutor(redisConfig.getMaxRetries(), Duration.millis(backoffMs));
+  }
+
+  @Override
+  public void setup() {
+    this.redisclient =
+        RedisClient.create(new RedisURI(host, port, java.time.Duration.ofMillis(DEFAULT_TIMEOUT)));
+  }
+
+  @Override
+  public BackOffExecutor getBackOffExecutor() {
+    return this.backOffExecutor;
+  }
+
+  /** Closes the connection, if one was opened, and shuts the client down; safe without setup. */
+  @Override
+  public void shutdown() {
+    if (connection != null) {
+      connection.close();
+      connection = null;
+      commands = null;
+    }
+    if (redisclient != null) {
+      redisclient.shutdown();
+    }
+  }
+
+  @Override
+  public void connect() {
+    if (!isConnected()) {
+      this.connection = this.redisclient.connect(new ByteArrayCodec());
+      this.commands = connection.async();
+    }
+  }
+
+  @Override
+  public boolean isConnected() {
+    return connection != null;
+  }
+
+  /**
+   * Waits up to {@value #SYNC_TIMEOUT_SECONDS} seconds for all pending commands to complete, then
+   * forgets them.
+   *
+   * @throws RedisCommandTimeoutException if some commands have not completed within the timeout;
+   *     without this the caller would treat writes that never finished as successful
+   */
+  @Override
+  public void sync() {
+    try {
+      boolean completed =
+          LettuceFutures.awaitAll(
+              SYNC_TIMEOUT_SECONDS, TimeUnit.SECONDS, futures.toArray(new RedisFuture[0]));
+      if (!completed) {
+        throw new RedisCommandTimeoutException(
+            String.format(
+                "Timed out after %d seconds waiting for %d Redis command(s) to complete",
+                SYNC_TIMEOUT_SECONDS, futures.size()));
+      }
+    } finally {
+      futures.clear();
+    }
+  }
+
+  @Override
+  public void pexpire(byte[] key, Long expiryMillis) {
+    // Tracked like every other command so that sync() also waits for it.
+    futures.add(commands.pexpire(key, expiryMillis));
+  }
+
+  @Override
+  public void append(byte[] key, byte[] value) {
+    futures.add(commands.append(key, value));
+  }
+
+  @Override
+  public void set(byte[] key, byte[] value) {
+    futures.add(commands.set(key, value));
+  }
+
+  @Override
+  public void lpush(byte[] key, byte[] value) {
+    futures.add(commands.lpush(key, value));
+  }
+
+  @Override
+  public void rpush(byte[] key, byte[] value) {
+    futures.add(commands.rpush(key, value));
+  }
+
+  @Override
+  public void sadd(byte[] key, byte[] value) {
+    futures.add(commands.sadd(key, value));
+  }
+
+  @Override
+  public void zadd(byte[] key, Long score, byte[] value) {
+    futures.add(commands.zadd(key, score, value));
+  }
+}
diff --git a/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/retrieval/RedisOnlineRetrieverTest.java 
b/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/retrieval/RedisOnlineRetrieverTest.java
new file mode 100644
index 0000000000..9fabc5fae7
--- /dev/null
+++ b/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/retrieval/RedisOnlineRetrieverTest.java
@@ -0,0 +1,275 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright 2018-2020 The Feast Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package feast.storage.connectors.redis.retrieval;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.equalTo;
+import static org.mockito.Mockito.when;
+import static org.mockito.MockitoAnnotations.initMocks;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
+import com.google.protobuf.AbstractMessageLite;
+import com.google.protobuf.Duration;
+import com.google.protobuf.Timestamp;
+import feast.core.FeatureSetProto.EntitySpec;
+import feast.core.FeatureSetProto.FeatureSetSpec;
+import feast.core.FeatureSetProto.FeatureSpec;
+import feast.serving.ServingAPIProto.FeatureReference;
+import feast.serving.ServingAPIProto.GetOnlineFeaturesRequest.EntityRow;
+import feast.storage.RedisProto.RedisKey;
+import feast.storage.api.retrieval.FeatureSetRequest;
+import feast.types.FeatureRowProto.FeatureRow;
+import feast.types.FieldProto.Field;
+import feast.types.ValueProto.Value;
+import io.lettuce.core.KeyValue;
+import io.lettuce.core.api.StatefulRedisConnection;
+import io.lettuce.core.api.sync.RedisCommands;
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mock;
+/** Unit tests for {@link RedisOnlineRetriever} running against a mocked Lettuce connection. */
+public class RedisOnlineRetrieverTest {
+
+  @Mock StatefulRedisConnection<byte[], byte[]> connection;
+
+  @Mock RedisCommands<byte[], byte[]> syncCommands;
+
+  private RedisOnlineRetriever redisOnlineRetriever;
+  private byte[][] redisKeyList;
+  /** Wires the mocks and precomputes the serialized Redis keys that the tests stub mget with. */
+  @Before
+  public void setUp() {
+    initMocks(this);
+    when(connection.sync()).thenReturn(syncCommands);
+    redisOnlineRetriever = new RedisOnlineRetriever(connection);
+    redisKeyList =
+        Lists.newArrayList(
+                RedisKey.newBuilder()
+                    .setFeatureSet("project/featureSet:1")
+                    .addAllEntities(
+                        Lists.newArrayList(
+                            Field.newBuilder().setName("entity1").setValue(intValue(1)).build(),
+                            Field.newBuilder().setName("entity2").setValue(strValue("a")).build()))
+                    .build(),
+                RedisKey.newBuilder()
+                    .setFeatureSet("project/featureSet:1")
+                    .addAllEntities(
+                        Lists.newArrayList(
+                            Field.newBuilder().setName("entity1").setValue(intValue(2)).build(),
+                            Field.newBuilder().setName("entity2").setValue(strValue("b")).build()))
+                    .build())
+            .stream()
+            .map(AbstractMessageLite::toByteArray)
+            .collect(Collectors.toList())
+            .toArray(new byte[0][0]);
+  }
+  /** Both keys present: rows come back with feature names and the feature set id filled in. */
+  @Test
+  public void shouldReturnResponseWithValuesIfKeysPresent() {
+    FeatureSetRequest featureSetRequest =
+        FeatureSetRequest.newBuilder()
+            .setSpec(getFeatureSetSpec())
+            .addFeatureReference(
+                FeatureReference.newBuilder()
+                    .setName("feature1")
+                    .setVersion(1)
+                    .setProject("project")
+                    .build())
+            .addFeatureReference(
+                FeatureReference.newBuilder()
+                    .setName("feature2")
+                    .setVersion(1)
+                    .setProject("project")
+                    .build())
+            .build();
+    List<EntityRow> entityRows =
+        ImmutableList.of(
+            EntityRow.newBuilder()
+                .setEntityTimestamp(Timestamp.newBuilder().setSeconds(100))
+                .putFields("entity1", intValue(1))
+                .putFields("entity2", strValue("a"))
+                .build(),
+            EntityRow.newBuilder()
+                .setEntityTimestamp(Timestamp.newBuilder().setSeconds(100))
+                .putFields("entity1", intValue(2))
+                .putFields("entity2", strValue("b"))
+                .build());
+
+    List<FeatureRow> featureRows =
+        Lists.newArrayList(
+            FeatureRow.newBuilder()
+                .setEventTimestamp(Timestamp.newBuilder().setSeconds(100))
+                .addAllFields(
+                    Lists.newArrayList(
+                        Field.newBuilder().setValue(intValue(1)).build(),
+                        Field.newBuilder().setValue(intValue(1)).build()))
+                .build(),
+            FeatureRow.newBuilder()
+                .setEventTimestamp(Timestamp.newBuilder().setSeconds(100))
+                .addAllFields(
+                    Lists.newArrayList(
+                        Field.newBuilder().setValue(intValue(2)).build(),
+                        Field.newBuilder().setValue(intValue(2)).build()))
+                .build());
+
+    List<KeyValue<byte[], byte[]>> featureRowBytes =
+        featureRows.stream()
+            .map(x -> KeyValue.from(new byte[1], Optional.of(x.toByteArray())))
+            .collect(Collectors.toList());
+
+    redisOnlineRetriever = new RedisOnlineRetriever(connection);
+    when(connection.sync()).thenReturn(syncCommands);
+    when(syncCommands.mget(redisKeyList)).thenReturn(featureRowBytes);
+
+    List<List<FeatureRow>> expected =
+        List.of(
+            Lists.newArrayList(
+                FeatureRow.newBuilder()
+                    .setEventTimestamp(Timestamp.newBuilder().setSeconds(100))
+                    .setFeatureSet("project/featureSet:1")
+                    .addAllFields(
+                        Lists.newArrayList(
+                            Field.newBuilder().setName("feature1").setValue(intValue(1)).build(),
+                            Field.newBuilder().setName("feature2").setValue(intValue(1)).build()))
+                    .build(),
+                FeatureRow.newBuilder()
+                    .setEventTimestamp(Timestamp.newBuilder().setSeconds(100))
+                    .setFeatureSet("project/featureSet:1")
+                    .addAllFields(
+                        Lists.newArrayList(
+                            Field.newBuilder().setName("feature1").setValue(intValue(2)).build(),
+                            Field.newBuilder().setName("feature2").setValue(intValue(2)).build()))
+                    .build()));
+
+    List<List<FeatureRow>> actual =
+        redisOnlineRetriever.getOnlineFeatures(entityRows, List.of(featureSetRequest));
+    assertThat(actual, equalTo(expected));
+  }
+  /** Second key missing: its row has named fields with unset values and no event timestamp. */
+  @Test
+  public void shouldReturnResponseWithUnsetValuesIfKeysNotPresent() {
+    FeatureSetRequest featureSetRequest =
+        FeatureSetRequest.newBuilder()
+            .setSpec(getFeatureSetSpec())
+            .addFeatureReference(
+                FeatureReference.newBuilder()
+                    .setName("feature1")
+                    .setVersion(1)
+                    .setProject("project")
+                    .build())
+            .addFeatureReference(
+                FeatureReference.newBuilder()
+                    .setName("feature2")
+                    .setVersion(1)
+                    .setProject("project")
+                    .build())
+            .build();
+    List<EntityRow> entityRows =
+        ImmutableList.of(
+            EntityRow.newBuilder()
+                .setEntityTimestamp(Timestamp.newBuilder().setSeconds(100))
+                .putFields("entity1", intValue(1))
+                .putFields("entity2", strValue("a"))
+                .build(),
+            EntityRow.newBuilder()
+                .setEntityTimestamp(Timestamp.newBuilder().setSeconds(100))
+                .putFields("entity1", intValue(2))
+                .putFields("entity2", strValue("b"))
+                .build());
+
+    List<FeatureRow> featureRows =
+        Lists.newArrayList(
+            FeatureRow.newBuilder()
+                .setEventTimestamp(Timestamp.newBuilder().setSeconds(100))
+                .addAllFields(
+                    Lists.newArrayList(
+                        Field.newBuilder().setValue(intValue(1)).build(),
+                        Field.newBuilder().setValue(intValue(1)).build()))
+                .build());
+
+    List<KeyValue<byte[], byte[]>> featureRowBytes =
+        featureRows.stream()
+            .map(x -> KeyValue.from(new byte[1], Optional.of(x.toByteArray())))
+            .collect(Collectors.toList());
+    featureRowBytes.add(null);
+
+    redisOnlineRetriever = new RedisOnlineRetriever(connection);
+    when(connection.sync()).thenReturn(syncCommands);
+    when(syncCommands.mget(redisKeyList)).thenReturn(featureRowBytes);
+
+    List<List<FeatureRow>> expected =
+        List.of(
+            Lists.newArrayList(
+                FeatureRow.newBuilder()
+                    .setEventTimestamp(Timestamp.newBuilder().setSeconds(100))
+                    .setFeatureSet("project/featureSet:1")
+                    .addAllFields(
+                        Lists.newArrayList(
+                            Field.newBuilder().setName("feature1").setValue(intValue(1)).build(),
+                            Field.newBuilder().setName("feature2").setValue(intValue(1)).build()))
+                    .build(),
+                FeatureRow.newBuilder()
+                    .setFeatureSet("project/featureSet:1")
+                    .addAllFields(
+                        Lists.newArrayList(
+                            Field.newBuilder().setName("feature1").build(),
+                            Field.newBuilder().setName("feature2").build()))
+                    .build()));
+
+    List<List<FeatureRow>> actual =
+        redisOnlineRetriever.getOnlineFeatures(entityRows, List.of(featureSetRequest));
+    assertThat(actual, equalTo(expected));
+  }
+  /** Wraps {@code val} in a {@link Value} using the int64 field. */
+  private Value intValue(int val) {
+    return Value.newBuilder().setInt64Val(val).build();
+  }
+  /** Wraps {@code val} in a {@link Value} using the string field. */
+  private Value strValue(String val) {
+    return Value.newBuilder().setStringVal(val).build();
+  }
+  /** Spec "project/featureSet" v1 with two entities, two features and a 30 s max age. */
+  private FeatureSetSpec getFeatureSetSpec() {
+    return FeatureSetSpec.newBuilder()
+        .setProject("project")
+        .setName("featureSet")
+        .setVersion(1)
+        .addEntities(EntitySpec.newBuilder().setName("entity1"))
+        .addEntities(EntitySpec.newBuilder().setName("entity2"))
+        .addFeatures(FeatureSpec.newBuilder().setName("feature1"))
+        .addFeatures(FeatureSpec.newBuilder().setName("feature2"))
+        .setMaxAge(Duration.newBuilder().setSeconds(30)) // default
+        .build();
+  }
+  /** Same spec with a zero max age; NOTE(review): not referenced by any test in this class. */
+  private FeatureSetSpec getFeatureSetSpecWithNoMaxAge() {
+    return FeatureSetSpec.newBuilder()
+        .setProject("project")
+        .setName("featureSet")
+        .setVersion(1)
+        .addEntities(EntitySpec.newBuilder().setName("entity1"))
+        .addEntities(EntitySpec.newBuilder().setName("entity2"))
+        .addFeatures(FeatureSpec.newBuilder().setName("feature1"))
+        .addFeatures(FeatureSpec.newBuilder().setName("feature2"))
+        .setMaxAge(Duration.newBuilder().setSeconds(0).setNanos(0).build())
+        .build();
+  }
+}
diff --git a/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/test/TestUtil.java b/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/test/TestUtil.java
new file mode 100644
index 0000000000..66aba44bc2
--- /dev/null
+++ b/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/test/TestUtil.java
@@ -0,0 +1,44 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright 2018-2020 The Feast Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package feast.storage.connectors.redis.test;
+
+import java.io.IOException;
+import redis.embedded.RedisServer;
+/** Test helpers for the Redis connector. */
+public class TestUtil {
+  public static class LocalRedis {
+    /** The embedded server most recently created by {@link #start(int)}; shared statically. */
+    private static RedisServer server;
+
+    /**
+     * Starts a local Redis for use in testing at "localhost", overwriting the stored reference.
+     *
+     * @param port port number
+     * @throws IOException if Redis failed to start
+     */
+    public static void start(int port) throws IOException {
+      server = new RedisServer(port);
+      server.start();
+    }
+    /** Stops the server created by {@link #start(int)}, if any; the reference is not cleared. */
+    public static void stop() {
+      if (server != null) {
+        server.stop();
+      }
+    }
+  }
+}
diff --git a/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/write/RedisFeatureSinkTest.java b/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/write/RedisFeatureSinkTest.java
new file mode 100644
index 0000000000..ddabed8fad
--- /dev/null
+++ b/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/write/RedisFeatureSinkTest.java
@@ -0,0 +1,479 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright 2018-2019 The Feast Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package feast.storage.connectors.redis.write;
+
+import static feast.storage.common.testing.TestUtil.field;
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.protobuf.Timestamp;
+import feast.core.FeatureSetProto.EntitySpec;
+import feast.core.FeatureSetProto.FeatureSetSpec;
+import feast.core.FeatureSetProto.FeatureSpec;
+import feast.core.StoreProto;
+import feast.core.StoreProto.Store.RedisConfig;
+import feast.storage.RedisProto.RedisKey;
+import feast.types.FeatureRowProto.FeatureRow;
+import feast.types.FieldProto.Field;
+import feast.types.ValueProto.Value;
+import feast.types.ValueProto.ValueType.Enum;
+import io.lettuce.core.RedisClient;
+import io.lettuce.core.RedisURI;
+import io.lettuce.core.api.StatefulRedisConnection;
+import io.lettuce.core.api.sync.RedisStringCommands;
+import io.lettuce.core.codec.ByteArrayCodec;
+import java.io.IOException;
+import java.util.*;
+import java.util.concurrent.ScheduledFuture;
+import java.util.concurrent.ScheduledThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.values.PCollection;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import redis.embedded.Redis;
+import redis.embedded.RedisServer;
+/** Tests {@link RedisFeatureSink} writes end to end against an embedded Redis server. */
+public class RedisFeatureSinkTest {
+  @Rule public transient TestPipeline p = TestPipeline.create();
+
+  private static final String REDIS_HOST = "localhost";
+  private static final int REDIS_PORT = 51234;
+  private Redis redis;
+  private RedisClient redisClient;
+  private RedisStringCommands<byte[], byte[]> sync;
+
+  private RedisFeatureSink redisFeatureSink;
+  /** Starts embedded Redis, opens a verification connection and builds a sink for two specs. */
+  @Before
+  public void setUp() throws IOException {
+    redis = new RedisServer(REDIS_PORT);
+    redis.start();
+    redisClient =
+        RedisClient.create(new RedisURI(REDIS_HOST, REDIS_PORT, java.time.Duration.ofMillis(2000)));
+    StatefulRedisConnection<byte[], byte[]> connection = redisClient.connect(new ByteArrayCodec());
+    sync = connection.sync();
+
+    FeatureSetSpec spec1 =
+        FeatureSetSpec.newBuilder()
+            .setName("fs")
+            .setVersion(1)
+            .setProject("myproject")
+            .addEntities(EntitySpec.newBuilder().setName("entity").setValueType(Enum.INT64).build())
+            .addFeatures(
+                FeatureSpec.newBuilder().setName("feature").setValueType(Enum.STRING).build())
+            .build();
+
+    FeatureSetSpec spec2 =
+        FeatureSetSpec.newBuilder()
+            .setName("feature_set")
+            .setProject("myproject")
+            .setVersion(1)
+            .addEntities(
+                EntitySpec.newBuilder()
+                    .setName("entity_id_primary")
+                    .setValueType(Enum.INT32)
+                    .build())
+            .addEntities(
+                EntitySpec.newBuilder()
+                    .setName("entity_id_secondary")
+                    .setValueType(Enum.STRING)
+                    .build())
+            .addFeatures(
+                FeatureSpec.newBuilder().setName("feature_1").setValueType(Enum.STRING).build())
+            .addFeatures(
+                FeatureSpec.newBuilder().setName("feature_2").setValueType(Enum.INT64).build())
+            .build();
+
+    Map<String, FeatureSetSpec> specMap =
+        ImmutableMap.of("myproject/fs:1", spec1, "myproject/feature_set:1", spec2);
+    StoreProto.Store.RedisConfig redisConfig =
+        StoreProto.Store.RedisConfig.newBuilder().setHost(REDIS_HOST).setPort(REDIS_PORT).build();
+
+    redisFeatureSink =
+        RedisFeatureSink.builder().setFeatureSetSpecs(specMap).setRedisConfig(redisConfig).build();
+  }
+  /** Shuts down the verification client and stops the embedded server. */
+  @After
+  public void teardown() {
+    redisClient.shutdown();
+    redis.stop();
+  }
+  /** Writes two rows through the sink and checks the stored bytes under each expected key. */
+  @Test
+  public void shouldWriteToRedis() {
+
+    HashMap<RedisKey, FeatureRow> kvs = new LinkedHashMap<>();
+    kvs.put(
+        RedisKey.newBuilder()
+            .setFeatureSet("myproject/fs:1")
+            .addEntities(field("entity", 1, Enum.INT64))
+            .build(),
+        FeatureRow.newBuilder()
+            .setEventTimestamp(Timestamp.getDefaultInstance())
+            .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("one")))
+            .build());
+    kvs.put(
+        RedisKey.newBuilder()
+            .setFeatureSet("myproject/fs:1")
+            .addEntities(field("entity", 2, Enum.INT64))
+            .build(),
+        FeatureRow.newBuilder()
+            .setEventTimestamp(Timestamp.getDefaultInstance())
+            .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("two")))
+            .build());
+
+    List<FeatureRow> featureRows =
+        ImmutableList.of(
+            FeatureRow.newBuilder()
+                .setFeatureSet("myproject/fs:1")
+                .addFields(field("entity", 1, Enum.INT64))
+                .addFields(field("feature", "one", Enum.STRING))
+                .build(),
+            FeatureRow.newBuilder()
+                .setFeatureSet("myproject/fs:1")
+                .addFields(field("entity", 2, Enum.INT64))
+                .addFields(field("feature", "two", Enum.STRING))
+                .build());
+
+    p.apply(Create.of(featureRows)).apply(redisFeatureSink.write());
+    p.run();
+
+    kvs.forEach(
+        (key, value) -> {
+          byte[] actual = sync.get(key.toByteArray());
+          assertThat(actual, equalTo(value.toByteArray()));
+        });
+  }
+  /** Stops Redis, restarts it after 3 s, and expects the sink's retries to deliver the row. */
+  @Test(timeout = 10000)
+  public void shouldRetryFailConnection() throws InterruptedException {
+    RedisConfig redisConfig =
+        RedisConfig.newBuilder()
+            .setHost(REDIS_HOST)
+            .setPort(REDIS_PORT)
+            .setMaxRetries(4)
+            .setInitialBackoffMs(2000)
+            .build();
+    redisFeatureSink = redisFeatureSink.toBuilder().setRedisConfig(redisConfig).build();
+
+    HashMap<RedisKey, FeatureRow> kvs = new LinkedHashMap<>();
+    kvs.put(
+        RedisKey.newBuilder()
+            .setFeatureSet("myproject/fs:1")
+            .addEntities(field("entity", 1, Enum.INT64))
+            .build(),
+        FeatureRow.newBuilder()
+            .setEventTimestamp(Timestamp.getDefaultInstance())
+            .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("one")))
+            .build());
+
+    List<FeatureRow> featureRows =
+        ImmutableList.of(
+            FeatureRow.newBuilder()
+                .setFeatureSet("myproject/fs:1")
+                .addFields(field("entity", 1, Enum.INT64))
+                .addFields(field("feature", "one", Enum.STRING))
+                .build());
+
+    PCollection<Long> failedElementCount =
+        p.apply(Create.of(featureRows))
+            .apply(redisFeatureSink.write())
+            .getFailedInserts()
+            .apply(Count.globally());
+
+    redis.stop();
+    final ScheduledThreadPoolExecutor redisRestartExecutor = new ScheduledThreadPoolExecutor(1);
+    ScheduledFuture<?> scheduledRedisRestart =
+        redisRestartExecutor.schedule(
+            () -> {
+              redis.start();
+            },
+            3,
+            TimeUnit.SECONDS);
+
+    PAssert.that(failedElementCount).containsInAnyOrder(0L);
+    p.run();
+    scheduledRedisRestart.cancel(true);
+    redisRestartExecutor.shutdownNow(); // release the worker thread so it does not outlive the test
+    kvs.forEach(
+        (key, value) -> {
+          byte[] actual = sync.get(key.toByteArray());
+          assertThat(actual, equalTo(value.toByteArray()));
+        });
+  }
+  /** Points the sink at REDIS_PORT + 1 with Redis stopped and expects one failed insert. */
+  @Test
+  public void shouldProduceFailedElementIfRetryExceeded() {
+
+    RedisConfig redisConfig =
+        RedisConfig.newBuilder().setHost(REDIS_HOST).setPort(REDIS_PORT + 1).build();
+    redisFeatureSink = redisFeatureSink.toBuilder().setRedisConfig(redisConfig).build();
+
+    HashMap<RedisKey, FeatureRow> kvs = new LinkedHashMap<>();
+    kvs.put(
+        RedisKey.newBuilder()
+            .setFeatureSet("myproject/fs:1")
+            .addEntities(field("entity", 1, Enum.INT64))
+            .build(),
+        FeatureRow.newBuilder()
+            .setEventTimestamp(Timestamp.getDefaultInstance())
+            .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("one")))
+            .build());
+
+    List<FeatureRow> featureRows =
+        ImmutableList.of(
+            FeatureRow.newBuilder()
+                .setFeatureSet("myproject/fs:1")
+                .addFields(field("entity", 1, Enum.INT64))
+                .addFields(field("feature", "one", Enum.STRING))
+                .build());
+
+    PCollection<Long> failedElementCount =
+        p.apply(Create.of(featureRows))
+            .apply(redisFeatureSink.write())
+            .getFailedInserts()
+            .apply(Count.globally());
+
+    redis.stop();
+    PAssert.that(failedElementCount).containsInAnyOrder(1L);
+    p.run();
+  }
+  /** A repeated entity field must yield a key built from its first occurrence. */
+  @Test
+  public void shouldConvertRowWithDuplicateEntitiesToValidKey() {
+
+    FeatureRow offendingRow =
+        FeatureRow.newBuilder()
+            .setFeatureSet("myproject/feature_set:1")
+            .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
+            .addFields(
+                Field.newBuilder()
+                    .setName("entity_id_primary")
+                    .setValue(Value.newBuilder().setInt32Val(1)))
+            .addFields(
+                Field.newBuilder()
+                    .setName("entity_id_primary")
+                    .setValue(Value.newBuilder().setInt32Val(2)))
+            .addFields(
+                Field.newBuilder()
+                    .setName("entity_id_secondary")
+                    .setValue(Value.newBuilder().setStringVal("a")))
+            .addFields(
+                Field.newBuilder()
+                    .setName("feature_1")
+                    .setValue(Value.newBuilder().setStringVal("strValue1")))
+            .addFields(
+                Field.newBuilder()
+                    .setName("feature_2")
+                    .setValue(Value.newBuilder().setInt64Val(1001)))
+            .build();
+
+    RedisKey expectedKey =
+        RedisKey.newBuilder()
+            .setFeatureSet("myproject/feature_set:1")
+            .addEntities(
+                Field.newBuilder()
+                    .setName("entity_id_primary")
+                    .setValue(Value.newBuilder().setInt32Val(1)))
+            .addEntities(
+                Field.newBuilder()
+                    .setName("entity_id_secondary")
+                    .setValue(Value.newBuilder().setStringVal("a")))
+            .build();
+
+    FeatureRow expectedValue =
+        FeatureRow.newBuilder()
+            .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
+            .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("strValue1")))
+            .addFields(Field.newBuilder().setValue(Value.newBuilder().setInt64Val(1001)))
+            .build();
+
+    p.apply(Create.of(offendingRow)).apply(redisFeatureSink.write());
+
+    p.run();
+
+    byte[] actual = sync.get(expectedKey.toByteArray());
+    assertThat(actual, equalTo(expectedValue.toByteArray()));
+  }
+  /** Fields given out of spec order must be stored in the spec's entity and feature order. */
+  @Test
+  public void shouldConvertRowWithOutOfOrderFieldsToValidKey() {
+    FeatureRow offendingRow =
+        FeatureRow.newBuilder()
+            .setFeatureSet("myproject/feature_set:1")
+            .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
+            .addFields(
+                Field.newBuilder()
+                    .setName("entity_id_secondary")
+                    .setValue(Value.newBuilder().setStringVal("a")))
+            .addFields(
+                Field.newBuilder()
+                    .setName("entity_id_primary")
+                    .setValue(Value.newBuilder().setInt32Val(1)))
+            .addFields(
+                Field.newBuilder()
+                    .setName("feature_2")
+                    .setValue(Value.newBuilder().setInt64Val(1001)))
+            .addFields(
+                Field.newBuilder()
+                    .setName("feature_1")
+                    .setValue(Value.newBuilder().setStringVal("strValue1")))
+            .build();
+
+    RedisKey expectedKey =
+        RedisKey.newBuilder()
+            .setFeatureSet("myproject/feature_set:1")
+            .addEntities(
+                Field.newBuilder()
+                    .setName("entity_id_primary")
+                    .setValue(Value.newBuilder().setInt32Val(1)))
+            .addEntities(
+                Field.newBuilder()
+                    .setName("entity_id_secondary")
+                    .setValue(Value.newBuilder().setStringVal("a")))
+            .build();
+
+    List<Field> expectedFields =
+        Arrays.asList(
+            Field.newBuilder().setValue(Value.newBuilder().setStringVal("strValue1")).build(),
+            Field.newBuilder().setValue(Value.newBuilder().setInt64Val(1001)).build());
+    FeatureRow expectedValue =
+        FeatureRow.newBuilder()
+            .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
+            .addAllFields(expectedFields)
+            .build();
+
+    p.apply(Create.of(offendingRow)).apply(redisFeatureSink.write());
+
+    p.run();
+
+    byte[] actual = sync.get(expectedKey.toByteArray());
+    assertThat(actual, equalTo(expectedValue.toByteArray()));
+  }
+  /** A feature supplied twice in one row must be stored only once. */
+  @Test
+  public void shouldMergeDuplicateFeatureFields() {
+    FeatureRow featureRowWithDuplicatedFeatureFields =
+        FeatureRow.newBuilder()
+            .setFeatureSet("myproject/feature_set:1")
+            .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
+            .addFields(
+                Field.newBuilder()
+                    .setName("entity_id_primary")
+                    .setValue(Value.newBuilder().setInt32Val(1)))
+            .addFields(
+                Field.newBuilder()
+                    .setName("entity_id_secondary")
+                    .setValue(Value.newBuilder().setStringVal("a")))
+            .addFields(
+                Field.newBuilder()
+                    .setName("feature_1")
+                    .setValue(Value.newBuilder().setStringVal("strValue1")))
+            .addFields(
+                Field.newBuilder()
+                    .setName("feature_1")
+                    .setValue(Value.newBuilder().setStringVal("strValue1")))
+            .addFields(
+                Field.newBuilder()
+                    .setName("feature_2")
+                    .setValue(Value.newBuilder().setInt64Val(1001)))
+            .build();
+
+    RedisKey expectedKey =
+        RedisKey.newBuilder()
+            .setFeatureSet("myproject/feature_set:1")
+            .addEntities(
+                Field.newBuilder()
+                    .setName("entity_id_primary")
+                    .setValue(Value.newBuilder().setInt32Val(1)))
+            .addEntities(
+                Field.newBuilder()
+                    .setName("entity_id_secondary")
+                    .setValue(Value.newBuilder().setStringVal("a")))
+            .build();
+
+    FeatureRow expectedValue =
+        FeatureRow.newBuilder()
+            .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
+            .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("strValue1")))
+            .addFields(Field.newBuilder().setValue(Value.newBuilder().setInt64Val(1001)))
+            .build();
+
+    p.apply(Create.of(featureRowWithDuplicatedFeatureFields)).apply(redisFeatureSink.write());
+
+    p.run();
+
+    byte[] actual = sync.get(expectedKey.toByteArray());
+    assertThat(actual, equalTo(expectedValue.toByteArray()));
+  }
+  /** A feature absent from the row must be stored as the default {@link Value} instance. */
+  @Test
+  public void shouldPopulateMissingFeatureValuesWithDefaultInstance() {
+    FeatureRow featureRowWithDuplicatedFeatureFields =
+        FeatureRow.newBuilder()
+            .setFeatureSet("myproject/feature_set:1")
+            .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
+            .addFields(
+                Field.newBuilder()
+                    .setName("entity_id_primary")
+                    .setValue(Value.newBuilder().setInt32Val(1)))
+            .addFields(
+                Field.newBuilder()
+                    .setName("entity_id_secondary")
+                    .setValue(Value.newBuilder().setStringVal("a")))
+            .addFields(
+                Field.newBuilder()
+                    .setName("feature_1")
+                    .setValue(Value.newBuilder().setStringVal("strValue1")))
+            .build();
+
+    RedisKey expectedKey =
+        RedisKey.newBuilder()
+            .setFeatureSet("myproject/feature_set:1")
+            .addEntities(
+                Field.newBuilder()
+                    .setName("entity_id_primary")
+                    .setValue(Value.newBuilder().setInt32Val(1)))
+            .addEntities(
+                Field.newBuilder()
+                    .setName("entity_id_secondary")
+                    .setValue(Value.newBuilder().setStringVal("a")))
+            .build();
+
+    FeatureRow expectedValue =
+        FeatureRow.newBuilder()
+            .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
+            .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("strValue1")))
+            .addFields(Field.newBuilder().setValue(Value.getDefaultInstance()))
+            .build();
+
+    p.apply(Create.of(featureRowWithDuplicatedFeatureFields)).apply(redisFeatureSink.write());
+
+    p.run();
+
+    byte[] actual = sync.get(expectedKey.toByteArray());
+    assertThat(actual, equalTo(expectedValue.toByteArray()));
+  }
+}