From 056cfa1b21db4ff092b9d1f9c06f7300a4c9f4b7 Mon Sep 17 00:00:00 2001 From: Yongheng Lin Date: Thu, 30 Jun 2022 13:00:27 -0700 Subject: [PATCH] feat: Support retrieval from multiple feature views with different join keys (#2835) * feat: Support retrieving from multiple feature views Signed-off-by: Yongheng Lin * group by join keys instead of feature view Signed-off-by: Yongheng Lin * tolerate insufficient entities Signed-off-by: Yongheng Lin * mock registry.getEntityJoinKey Signed-off-by: Yongheng Lin * add integration test Signed-off-by: Yongheng Lin --- .../java/feast/serving/registry/Registry.java | 15 +++ .../serving/registry/RegistryRepository.java | 4 + .../service/OnlineServingServiceV2.java | 101 +++++++++++++++--- .../feast/serving/it/ServingBaseTests.java | 30 ++++++ .../service/OnlineServingServiceTest.java | 6 ++ 5 files changed, 140 insertions(+), 16 deletions(-) diff --git a/java/serving/src/main/java/feast/serving/registry/Registry.java b/java/serving/src/main/java/feast/serving/registry/Registry.java index bc953174ea..a7b28f7c66 100644 --- a/java/serving/src/main/java/feast/serving/registry/Registry.java +++ b/java/serving/src/main/java/feast/serving/registry/Registry.java @@ -33,6 +33,7 @@ public class Registry { private Map onDemandFeatureViewNameToSpec; private final Map featureServiceNameToSpec; + private final Map entityNameToJoinKey; Registry(RegistryProto.Registry registry) { this.registry = registry; @@ -60,6 +61,12 @@ public class Registry { .collect( Collectors.toMap( FeatureServiceProto.FeatureServiceSpec::getName, Function.identity())); + this.entityNameToJoinKey = + registry.getEntitiesList().stream() + .map(EntityProto.Entity::getSpec) + .collect( + Collectors.toMap( + EntityProto.EntitySpecV2::getName, EntityProto.EntitySpecV2::getJoinKey)); } public RegistryProto.Registry getRegistry() { @@ -115,4 +122,12 @@ public FeatureServiceProto.FeatureServiceSpec getFeatureServiceSpec(String name) } return spec; } + + public String getEntityJoinKey(String name) { + String joinKey = entityNameToJoinKey.get(name); + if (joinKey == null) { + throw new SpecRetrievalException(String.format("Unable to find entity with name: %s", name)); + } + return joinKey; + } } diff --git a/java/serving/src/main/java/feast/serving/registry/RegistryRepository.java b/java/serving/src/main/java/feast/serving/registry/RegistryRepository.java index 369493ee0f..023ec1a062 100644 --- a/java/serving/src/main/java/feast/serving/registry/RegistryRepository.java +++ b/java/serving/src/main/java/feast/serving/registry/RegistryRepository.java @@ -102,4 +102,8 @@ public Duration getMaxAge(ServingAPIProto.FeatureReferenceV2 featureReference) { public List getEntitiesList(ServingAPIProto.FeatureReferenceV2 featureReference) { return getFeatureViewSpec(featureReference).getEntitiesList(); } + + public String getEntityJoinKey(String name) { + return this.registry.getEntityJoinKey(name); + } } diff --git a/java/serving/src/main/java/feast/serving/service/OnlineServingServiceV2.java b/java/serving/src/main/java/feast/serving/service/OnlineServingServiceV2.java index 12e8a5b158..3751ee8119 100644 --- a/java/serving/src/main/java/feast/serving/service/OnlineServingServiceV2.java +++ b/java/serving/src/main/java/feast/serving/service/OnlineServingServiceV2.java @@ -34,7 +34,6 @@ import feast.serving.registry.RegistryRepository; import feast.serving.util.Metrics; import feast.storage.api.retriever.OnlineRetrieverV2; -import io.grpc.Status; import io.opentracing.Span; import io.opentracing.Tracer; import java.util.*; @@ -51,6 +50,11 @@ public class OnlineServingServiceV2 implements ServingServiceV2 { private final OnlineTransformationService onlineTransformationService; private final String project; + public static final String DUMMY_ENTITY_ID = "__dummy_id"; + public static final String DUMMY_ENTITY_VAL = ""; + public static final ValueProto.Value DUMMY_ENTITY_VALUE = + ValueProto.Value.newBuilder().setStringVal(DUMMY_ENTITY_VAL).build(); + public OnlineServingServiceV2( OnlineRetrieverV2 retriever, Tracer tracer, @@ -103,31 +107,18 @@ public ServingAPIProto.GetOnlineFeaturesResponse getOnlineFeatures( List> entityRows = getEntityRows(request); - List entityNames; - if (retrievedFeatureReferences.size() > 0) { - entityNames = this.registryRepository.getEntitiesList(retrievedFeatureReferences.get(0)); - } else { - throw new RuntimeException("Requested features list must not be empty"); - } - Span storageRetrievalSpan = tracer.buildSpan("storageRetrieval").start(); if (storageRetrievalSpan != null) { storageRetrievalSpan.setTag("entities", entityRows.size()); storageRetrievalSpan.setTag("features", retrievedFeatureReferences.size()); } + List> features = - retriever.getOnlineFeatures(entityRows, retrievedFeatureReferences, entityNames); + retrieveFeatures(retrievedFeatureReferences, entityRows); if (storageRetrievalSpan != null) { storageRetrievalSpan.finish(); } - if (features.size() != entityRows.size()) { - throw Status.INTERNAL - .withDescription( - "The no. of FeatureRow obtained from OnlineRetriever" - + "does not match no. of entityRow passed.") - .asRuntimeException(); - } Span postProcessingSpan = tracer.buildSpan("postProcessing").start(); @@ -255,6 +246,84 @@ private List> getEntityRows( return entityRows; } + private List> retrieveFeatures( + List featureReferences, List> entityRows) { + // Prepare feature reference to index mapping. This mapping will be used to arrange the + // retrieved features to the same order as in the input. + if (featureReferences.isEmpty()) { + throw new RuntimeException("Requested features list must not be empty."); + } + Map featureReferenceToIndexMap = + new HashMap<>(featureReferences.size()); + for (int i = 0; i < featureReferences.size(); i++) { + FeatureReferenceV2 featureReference = featureReferences.get(i); + if (featureReferenceToIndexMap.containsKey(featureReference)) { + throw new RuntimeException( + String.format( + "Found duplicate features %s:%s.", + featureReference.getFeatureViewName(), featureReference.getFeatureName())); + } + featureReferenceToIndexMap.put(featureReference, i); + } + + // Create placeholders for retrieved features. + List> features = new ArrayList<>(entityRows.size()); + for (int i = 0; i < entityRows.size(); i++) { + List featuresPerEntity = + new ArrayList<>(featureReferences.size()); + for (int j = 0; j < featureReferences.size(); j++) { + featuresPerEntity.add(null); + } + features.add(featuresPerEntity); + } + + // Group feature references by join keys. + Map> groupNameToFeatureReferencesMap = + featureReferences.stream() + .collect( + Collectors.groupingBy( + featureReference -> + this.registryRepository.getEntitiesList(featureReference).stream() + .map(this.registryRepository::getEntityJoinKey) + .sorted() + .collect(Collectors.joining(",")))); + + // Retrieve features one group at a time. + for (List featureReferencesPerGroup : + groupNameToFeatureReferencesMap.values()) { + List entityNames = + this.registryRepository.getEntitiesList(featureReferencesPerGroup.get(0)); + List> entityRowsPerGroup = new ArrayList<>(entityRows.size()); + for (Map entityRow : entityRows) { + Map entityRowPerGroup = new HashMap<>(); + entityNames.stream() + .map(this.registryRepository::getEntityJoinKey) + .forEach( + joinKey -> { + if (joinKey.equals(DUMMY_ENTITY_ID)) { + entityRowPerGroup.put(joinKey, DUMMY_ENTITY_VALUE); + } else { + ValueProto.Value value = entityRow.get(joinKey); + if (value != null) { + entityRowPerGroup.put(joinKey, value); + } + } + }); + entityRowsPerGroup.add(entityRowPerGroup); + } + List> featuresPerGroup = + retriever.getOnlineFeatures(entityRowsPerGroup, featureReferencesPerGroup, entityNames); + for (int i = 0; i < featuresPerGroup.size(); i++) { + for (int j = 0; j < featureReferencesPerGroup.size(); j++) { + int k = featureReferenceToIndexMap.get(featureReferencesPerGroup.get(j)); + features.get(i).set(k, featuresPerGroup.get(i).get(j)); + } + } + } + + return features; + } + private void populateOnDemandFeatures( List onDemandFeatureReferences, List onDemandFeatureSources, diff --git a/java/serving/src/test/java/feast/serving/it/ServingBaseTests.java b/java/serving/src/test/java/feast/serving/it/ServingBaseTests.java index 30cba0cb06..66987e8c0d 100644 --- a/java/serving/src/test/java/feast/serving/it/ServingBaseTests.java +++ b/java/serving/src/test/java/feast/serving/it/ServingBaseTests.java @@ -172,5 +172,35 @@ public void shouldGetOnlineFeaturesWithStringEntity() { } } + @Test + public void shouldGetOnlineFeaturesFromAllFeatureViews() { + Map entityRows = + ImmutableMap.of( + "entity", + ValueProto.RepeatedValue.newBuilder() + .addVal(DataGenerator.createStrValue("key-1")) + .build(), + "driver_id", + ValueProto.RepeatedValue.newBuilder() + .addVal(DataGenerator.createInt64Value(1005)) + .build()); + + ImmutableList featureReferences = + ImmutableList.of( + "feature_view_0:feature_0", + "feature_view_0:feature_1", + "driver_hourly_stats:conv_rate", + "driver_hourly_stats:avg_daily_trips"); + + ServingAPIProto.GetOnlineFeaturesRequest req = + TestUtils.createOnlineFeatureRequest(featureReferences, entityRows); + + ServingAPIProto.GetOnlineFeaturesResponse resp = servingStub.getOnlineFeatures(req); + + for (final int featureIdx : List.of(0, 1, 2, 3)) { + assertEquals(FieldStatus.PRESENT, resp.getResults(featureIdx).getStatuses(0)); + } + } + abstract void updateRegistryFile(RegistryProto.Registry registry); } diff --git a/java/serving/src/test/java/feast/serving/service/OnlineServingServiceTest.java b/java/serving/src/test/java/feast/serving/service/OnlineServingServiceTest.java index 64d2e20c9b..933e38f056 100644 --- a/java/serving/src/test/java/feast/serving/service/OnlineServingServiceTest.java +++ b/java/serving/src/test/java/feast/serving/service/OnlineServingServiceTest.java @@ -170,6 +170,8 @@ public void shouldReturnResponseWithValuesAndMetadataIfKeysPresent() { .thenReturn(featureSpecs.get(0)); when(registry.getFeatureSpec(mockedFeatureRows.get(3).getFeatureReference())) .thenReturn(featureSpecs.get(1)); + when(registry.getEntityJoinKey("entity1")).thenReturn("entity1"); + when(registry.getEntityJoinKey("entity2")).thenReturn("entity2"); when(tracer.buildSpan(ArgumentMatchers.any())).thenReturn(Mockito.mock(SpanBuilder.class)); @@ -237,6 +239,8 @@ public void shouldReturnResponseWithUnsetValuesAndMetadataIfKeysNotPresent() { .thenReturn(featureSpecs.get(0)); when(registry.getFeatureSpec(mockedFeatureRows.get(1).getFeatureReference())) .thenReturn(featureSpecs.get(1)); + when(registry.getEntityJoinKey("entity1")).thenReturn("entity1"); + when(registry.getEntityJoinKey("entity2")).thenReturn("entity2"); when(tracer.buildSpan(ArgumentMatchers.any())).thenReturn(Mockito.mock(SpanBuilder.class)); @@ -314,6 +318,8 @@ public void shouldReturnResponseWithValuesAndMetadataIfMaxAgeIsExceeded() { .thenReturn(featureSpecs.get(1)); when(registry.getFeatureSpec(mockedFeatureRows.get(5).getFeatureReference())) .thenReturn(featureSpecs.get(0)); + when(registry.getEntityJoinKey("entity1")).thenReturn("entity1"); + when(registry.getEntityJoinKey("entity2")).thenReturn("entity2"); when(tracer.buildSpan(ArgumentMatchers.any())).thenReturn(Mockito.mock(SpanBuilder.class));