Encode feature row before storing in Redis (feast-dev#530)
* Encode feature row before storing in Redis

* Include encoding as part of RedisMutationDoFn

Co-authored-by: Khor Shu Heng <[email protected]>
khorshuheng committed Mar 16, 2020
1 parent 3dd6041 commit f3a4886
Showing 8 changed files with 469 additions and 12 deletions.
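In short, the new encoding strips everything that the FeatureSetSpec can recover before a row is written to Redis: the feature set reference and the field names are cleared, entity fields are dropped (they already live in the RedisKey), and the remaining feature values are ordered by feature name. A minimal before/after sketch, using hypothetical feature names and values rather than anything taken from this diff:

// Sketch only: hypothetical feature names and values.
FeatureRow incoming =
    FeatureRow.newBuilder()
        .setFeatureSet("feature_set")
        .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
        .addFields(
            Field.newBuilder()
                .setName("feature_2")
                .setValue(Value.newBuilder().setInt64Val(1001)))
        .addFields(
            Field.newBuilder()
                .setName("feature_1")
                .setValue(Value.newBuilder().setStringVal("strValue1")))
        .build();

// What getValue(...) serializes to Redis: no feature set reference, no field
// names, and fields sorted by feature name (feature_1 before feature_2).
FeatureRow stored =
    FeatureRow.newBuilder()
        .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
        .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("strValue1")))
        .addFields(Field.newBuilder().setValue(Value.newBuilder().setInt64Val(1001)))
        .build();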
@@ -18,12 +18,15 @@

import feast.core.FeatureSetProto.EntitySpec;
import feast.core.FeatureSetProto.FeatureSet;
import feast.core.FeatureSetProto.FeatureSetSpec;
import feast.core.FeatureSetProto.FeatureSpec;
import feast.storage.RedisProto.RedisKey;
import feast.storage.RedisProto.RedisKey.Builder;
import feast.store.serving.redis.RedisCustomIO.Method;
import feast.store.serving.redis.RedisCustomIO.RedisMutation;
import feast.types.FeatureRowProto.FeatureRow;
import feast.types.FieldProto.Field;
import feast.types.ValueProto;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -64,14 +67,45 @@ private RedisKey getKey(FeatureRow featureRow) {
return redisKeyBuilder.build();
}

private byte[] getValue(FeatureRow featureRow) {
FeatureSetSpec spec = featureSets.get(featureRow.getFeatureSet()).getSpec();

List<String> featureNames =
spec.getFeaturesList().stream().map(FeatureSpec::getName).collect(Collectors.toList());
Map<String, Field> fieldValueOnlyMap =
featureRow.getFieldsList().stream()
.filter(field -> featureNames.contains(field.getName()))
.distinct()
.collect(
Collectors.toMap(
Field::getName,
field -> Field.newBuilder().setValue(field.getValue()).build()));

List<Field> values =
featureNames.stream()
.sorted()
.map(
featureName ->
fieldValueOnlyMap.getOrDefault(
featureName,
Field.newBuilder().setValue(ValueProto.Value.getDefaultInstance()).build()))
.collect(Collectors.toList());

return FeatureRow.newBuilder()
.setEventTimestamp(featureRow.getEventTimestamp())
.addAllFields(values)
.build()
.toByteArray();
}

/** Output a redis mutation object for each feature row. */
@ProcessElement
public void processElement(ProcessContext context) {
FeatureRow featureRow = context.element();
try {
- RedisKey key = getKey(featureRow);
- RedisMutation redisMutation =
-     new RedisMutation(Method.SET, key.toByteArray(), featureRow.toByteArray(), null, null);
+ byte[] key = getKey(featureRow).toByteArray();
+ byte[] value = getValue(featureRow);
+ RedisMutation redisMutation = new RedisMutation(Method.SET, key, value, null, null);
context.output(redisMutation);
} catch (Exception e) {
log.error(e.getMessage(), e);
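Because the stored rows no longer carry field names or the feature set reference, a reader has to reattach names from the FeatureSetSpec using the same sorted order that getValue relies on. A hypothetical decode helper sketching that contract (not part of this commit; decodeValue and its placement are assumptions, while the accessors are the generated protobuf API):

// Hypothetical helper: field i of the stored row corresponds to the i-th
// feature name in lexicographic order, mirroring getValue's sorting.
// Requires com.google.protobuf.InvalidProtocolBufferException.
private FeatureRow decodeValue(FeatureSetSpec spec, byte[] storedValue)
    throws InvalidProtocolBufferException {
  FeatureRow stripped = FeatureRow.parseFrom(storedValue);
  List<String> sortedNames =
      spec.getFeaturesList().stream()
          .map(FeatureSpec::getName)
          .sorted()
          .collect(Collectors.toList());
  FeatureRow.Builder decoded =
      FeatureRow.newBuilder().setEventTimestamp(stripped.getEventTimestamp());
  for (int i = 0; i < sortedNames.size(); i++) {
    decoded.addFields(stripped.getFields(i).toBuilder().setName(sortedNames.get(i)));
  }
  return decoded.build();
}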
19 changes: 19 additions & 0 deletions ingestion/src/test/java/feast/ingestion/ImportJobTest.java
@@ -37,6 +37,7 @@
import feast.test.TestUtil.LocalKafka;
import feast.test.TestUtil.LocalRedis;
import feast.types.FeatureRowProto.FeatureRow;
import feast.types.FieldProto;
import feast.types.ValueProto.ValueType.Enum;
import io.lettuce.core.RedisClient;
import io.lettuce.core.RedisURI;
@@ -50,6 +51,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.PipelineResult.State;
@@ -189,6 +191,23 @@ public void runPipeline_ShouldWriteToRedisCorrectlyGivenValidSpecAndFeatureRow()
FeatureRow randomRow = TestUtil.createRandomFeatureRow(featureSet);
RedisKey redisKey = TestUtil.createRedisKey(featureSet, randomRow);
input.add(randomRow);
List<FieldProto.Field> fields =
randomRow.getFieldsList().stream()
.filter(
field ->
spec.getFeaturesList().stream()
.map(FeatureSpec::getName)
.collect(Collectors.toList())
.contains(field.getName()))
.map(field -> field.toBuilder().clearName().build())
.collect(Collectors.toList());
randomRow =
randomRow
.toBuilder()
.clearFields()
.addAllFields(fields)
.clearFeatureSet()
.build();
expected.put(redisKey, randomRow);
});

@@ -29,10 +29,7 @@
import feast.types.FieldProto.Field;
import feast.types.ValueProto.Value;
import feast.types.ValueProto.ValueType.Enum;
- import java.util.Arrays;
- import java.util.Collections;
- import java.util.HashMap;
- import java.util.Map;
+ import java.util.*;
import org.apache.beam.sdk.extensions.protobuf.ProtoCoder;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
@@ -96,6 +93,14 @@ public void shouldConvertRowWithDuplicateEntitiesToValidKey() {
Field.newBuilder()
.setName("entity_id_secondary")
.setValue(Value.newBuilder().setStringVal("a")))
.addFields(
Field.newBuilder()
.setName("feature_1")
.setValue(Value.newBuilder().setStringVal("strValue1")))
.addFields(
Field.newBuilder()
.setName("feature_2")
.setValue(Value.newBuilder().setInt64Val(1001)))
.build();

PCollection<RedisMutation> output =
@@ -116,22 +121,29 @@ public void shouldConvertRowWithDuplicateEntitiesToValidKey() {
.setValue(Value.newBuilder().setStringVal("a")))
.build();

FeatureRow expectedValue =
FeatureRow.newBuilder()
.setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
.addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("strValue1")))
.addFields(Field.newBuilder().setValue(Value.newBuilder().setInt64Val(1001)))
.build();

PAssert.that(output)
.satisfies(
(SerializableFunction<Iterable<RedisMutation>, Void>)
input -> {
input.forEach(
rm -> {
assert (Arrays.equals(rm.getKey(), expectedKey.toByteArray()));
- assert (Arrays.equals(rm.getValue(), offendingRow.toByteArray()));
+ assert (Arrays.equals(rm.getValue(), expectedValue.toByteArray()));
});
return null;
});
p.run();
}

@Test
- public void shouldConvertRowWithOutOfOrderEntitiesToValidKey() {
+ public void shouldConvertRowWithOutOfOrderFieldsToValidKey() {
Map<String, FeatureSetProto.FeatureSet> featureSets = new HashMap<>();
featureSets.put("feature_set", fs);

@@ -147,6 +159,14 @@ public void shouldConvertRowWithOutOfOrderEntitiesToValidKey() {
Field.newBuilder()
.setName("entity_id_primary")
.setValue(Value.newBuilder().setInt32Val(1)))
.addFields(
Field.newBuilder()
.setName("feature_2")
.setValue(Value.newBuilder().setInt64Val(1001)))
.addFields(
Field.newBuilder()
.setName("feature_1")
.setValue(Value.newBuilder().setStringVal("strValue1")))
.build();

PCollection<RedisMutation> output =
@@ -167,14 +187,156 @@ public void shouldConvertRowWithOutOfOrderEntitiesToValidKey() {
.setValue(Value.newBuilder().setStringVal("a")))
.build();

List<Field> expectedFields =
Arrays.asList(
Field.newBuilder().setValue(Value.newBuilder().setStringVal("strValue1")).build(),
Field.newBuilder().setValue(Value.newBuilder().setInt64Val(1001)).build());
FeatureRow expectedValue =
FeatureRow.newBuilder()
.setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
.addAllFields(expectedFields)
.build();

PAssert.that(output)
.satisfies(
(SerializableFunction<Iterable<RedisMutation>, Void>)
input -> {
input.forEach(
rm -> {
assert (Arrays.equals(rm.getKey(), expectedKey.toByteArray()));
assert (Arrays.equals(rm.getValue(), expectedValue.toByteArray()));
});
return null;
});
p.run();
}

@Test
public void shouldMergeDuplicateFeatureFields() {
Map<String, FeatureSetProto.FeatureSet> featureSets = new HashMap<>();
featureSets.put("feature_set", fs);

FeatureRow featureRowWithDuplicatedFeatureFields =
FeatureRow.newBuilder()
.setFeatureSet("feature_set")
.setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
.addFields(
Field.newBuilder()
.setName("entity_id_primary")
.setValue(Value.newBuilder().setInt32Val(1)))
.addFields(
Field.newBuilder()
.setName("entity_id_secondary")
.setValue(Value.newBuilder().setStringVal("a")))
.addFields(
Field.newBuilder()
.setName("feature_1")
.setValue(Value.newBuilder().setStringVal("strValue1")))
.addFields(
Field.newBuilder()
.setName("feature_1")
.setValue(Value.newBuilder().setStringVal("strValue1")))
.addFields(
Field.newBuilder()
.setName("feature_2")
.setValue(Value.newBuilder().setInt64Val(1001)))
.build();

PCollection<RedisMutation> output =
p.apply(Create.of(Collections.singletonList(featureRowWithDuplicatedFeatureFields)))
.setCoder(ProtoCoder.of(FeatureRow.class))
.apply(ParDo.of(new FeatureRowToRedisMutationDoFn(featureSets)));

RedisKey expectedKey =
RedisKey.newBuilder()
.setFeatureSet("feature_set")
.addEntities(
Field.newBuilder()
.setName("entity_id_primary")
.setValue(Value.newBuilder().setInt32Val(1)))
.addEntities(
Field.newBuilder()
.setName("entity_id_secondary")
.setValue(Value.newBuilder().setStringVal("a")))
.build();

FeatureRow expectedValue =
FeatureRow.newBuilder()
.setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
.addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("strValue1")))
.addFields(Field.newBuilder().setValue(Value.newBuilder().setInt64Val(1001)))
.build();

PAssert.that(output)
.satisfies(
(SerializableFunction<Iterable<RedisMutation>, Void>)
input -> {
input.forEach(
rm -> {
assert (Arrays.equals(rm.getKey(), expectedKey.toByteArray()));
assert (Arrays.equals(rm.getValue(), expectedValue.toByteArray()));
});
return null;
});
p.run();
}
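One subtlety the test above pins down: the .distinct() in getValue only collapses fields that are byte-for-byte identical, which is exactly the duplicate case constructed here. A row that repeated a feature name with two different values would instead make Collectors.toMap throw IllegalStateException on the duplicate key. If last-write-wins were ever preferred, the three-argument toMap overload would express it; a sketch of that variant, which is not what this commit does:

// Hypothetical alternative to .distinct() + two-argument toMap: the merge
// function keeps the later value when a feature name repeats.
Map<String, Field> fieldValueOnlyMap =
    featureRow.getFieldsList().stream()
        .filter(field -> featureNames.contains(field.getName()))
        .collect(
            Collectors.toMap(
                Field::getName,
                field -> Field.newBuilder().setValue(field.getValue()).build(),
                (first, last) -> last));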

@Test
public void shouldPopulateMissingFeatureValuesWithDefaultInstance() {
Map<String, FeatureSetProto.FeatureSet> featureSets = new HashMap<>();
featureSets.put("feature_set", fs);

FeatureRow featureRowWithMissingFeatureField =
FeatureRow.newBuilder()
.setFeatureSet("feature_set")
.setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
.addFields(
Field.newBuilder()
.setName("entity_id_primary")
.setValue(Value.newBuilder().setInt32Val(1)))
.addFields(
Field.newBuilder()
.setName("entity_id_secondary")
.setValue(Value.newBuilder().setStringVal("a")))
.addFields(
Field.newBuilder()
.setName("feature_1")
.setValue(Value.newBuilder().setStringVal("strValue1")))
.build();

PCollection<RedisMutation> output =
p.apply(Create.of(Collections.singletonList(featureRowWithMissingFeatureField)))
.setCoder(ProtoCoder.of(FeatureRow.class))
.apply(ParDo.of(new FeatureRowToRedisMutationDoFn(featureSets)));

RedisKey expectedKey =
RedisKey.newBuilder()
.setFeatureSet("feature_set")
.addEntities(
Field.newBuilder()
.setName("entity_id_primary")
.setValue(Value.newBuilder().setInt32Val(1)))
.addEntities(
Field.newBuilder()
.setName("entity_id_secondary")
.setValue(Value.newBuilder().setStringVal("a")))
.build();

FeatureRow expectedValue =
FeatureRow.newBuilder()
.setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
.addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("strValue1")))
.addFields(Field.newBuilder().setValue(Value.getDefaultInstance()))
.build();

PAssert.that(output)
.satisfies(
(SerializableFunction<Iterable<RedisMutation>, Void>)
input -> {
input.forEach(
rm -> {
assert (Arrays.equals(rm.getKey(), expectedKey.toByteArray()));
- assert (Arrays.equals(rm.getValue(), offendingRow.toByteArray()));
+ assert (Arrays.equals(rm.getValue(), expectedValue.toByteArray()));
});
return null;
});
