From 4bac06d8613cba26fe31504aeda192dd0f2f0582 Mon Sep 17 00:00:00 2001 From: Pradithya Aria Date: Mon, 14 Jan 2019 11:33:42 +0800 Subject: [PATCH] Check the size of result against deduplicated request --- .../java/feast/core/service/SpecService.java | 26 +++++--- .../feast/core/service/SpecServiceTest.java | 65 ++++++++++++++++++- 2 files changed, 80 insertions(+), 11 deletions(-) diff --git a/core/src/main/java/feast/core/service/SpecService.java b/core/src/main/java/feast/core/service/SpecService.java index ba6d0b391e..c7c663092d 100644 --- a/core/src/main/java/feast/core/service/SpecService.java +++ b/core/src/main/java/feast/core/service/SpecService.java @@ -18,6 +18,7 @@ package feast.core.service; import com.google.common.base.Strings; +import com.google.common.collect.Sets; import com.google.protobuf.util.JsonFormat; import feast.core.dao.EntityInfoRepository; import feast.core.dao.FeatureGroupInfoRepository; @@ -38,6 +39,7 @@ import feast.specs.FeatureGroupSpecProto.FeatureGroupSpec; import feast.specs.FeatureSpecProto.FeatureSpec; import feast.specs.StorageSpecProto.StorageSpec; +import java.util.Set; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; @@ -83,8 +85,10 @@ public List getEntities(List ids) { if (ids.size() == 0) { throw new IllegalArgumentException("ids cannot be empty"); } - List entityInfos = this.entityInfoRepository.findAllById(ids); - if (entityInfos.size() < ids.size()) { + Set dedupIds = Sets.newHashSet(ids); + + List entityInfos = this.entityInfoRepository.findAllById(dedupIds); + if (entityInfos.size() < dedupIds.size()) { throw new RetrievalException( "unable to retrieve all entities requested"); // TODO: check and return exactly which ones } @@ -113,8 +117,10 @@ public List getFeatures(List ids) { if (ids.size() == 0) { throw new IllegalArgumentException("ids cannot be empty"); } - List featureInfos = this.featureInfoRepository.findAllById(ids); - if (featureInfos.size() < ids.size()) { + Set dedupIds = Sets.newHashSet(ids); + + List featureInfos = this.featureInfoRepository.findAllById(dedupIds); + if (featureInfos.size() < dedupIds.size()) { throw new RetrievalException( "unable to retrieve all features requested"); // TODO: check and return exactly which ones } @@ -143,8 +149,10 @@ public List getFeatureGroups(List ids) { if (ids.size() == 0) { throw new IllegalArgumentException("ids cannot be empty"); } - List featureGroupInfos = this.featureGroupInfoRepository.findAllById(ids); - if (featureGroupInfos.size() < ids.size()) { + Set dedupIds = Sets.newHashSet(ids); + + List featureGroupInfos = this.featureGroupInfoRepository.findAllById(dedupIds); + if (featureGroupInfos.size() < dedupIds.size()) { throw new RetrievalException( "unable to retrieve all feature groups requested"); // TODO: check and return exactly // which ones @@ -174,8 +182,10 @@ public List getStorage(List ids) { if (ids.size() == 0) { throw new IllegalArgumentException("ids cannot be empty"); } - List storageInfos = this.storageInfoRepository.findAllById(ids); - if (storageInfos.size() < ids.size()) { + Set dedupIds = Sets.newHashSet(ids); + + List storageInfos = this.storageInfoRepository.findAllById(dedupIds); + if (storageInfos.size() < dedupIds.size()) { throw new RetrievalException( "unable to retrieve all storage requested"); // TODO: check and return exactly which ones } diff --git a/core/src/test/java/feast/core/service/SpecServiceTest.java b/core/src/test/java/feast/core/service/SpecServiceTest.java index 531adf6b61..d4af688750 100644 --- a/core/src/test/java/feast/core/service/SpecServiceTest.java +++ b/core/src/test/java/feast/core/service/SpecServiceTest.java @@ -51,6 +51,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertThat; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.mockito.MockitoAnnotations.initMocks; @@ -104,7 +105,26 @@ public void shouldGetEntitiesMatchingIds() { EntityInfo entity2 = newTestEntityInfo("entity2"); ArrayList ids = Lists.newArrayList("entity1", "entity2"); - when(entityInfoRepository.findAllById(ids)).thenReturn(Lists.newArrayList(entity1, entity2)); + when(entityInfoRepository.findAllById(any(Iterable.class))).thenReturn(Lists.newArrayList(entity1, entity2)); + SpecService specService = + new SpecService( + entityInfoRepository, + featureInfoRepository, + storageInfoRepository, + featureGroupInfoRepository, + schemaManager); + List actual = specService.getEntities(ids); + List expected = Lists.newArrayList(entity1, entity2); + assertThat(actual, equalTo(expected)); + } + + @Test + public void shouldDeduplicateGetEntities() { + EntityInfo entity1 = newTestEntityInfo("entity1"); + EntityInfo entity2 = newTestEntityInfo("entity2"); + + ArrayList ids = Lists.newArrayList("entity1", "entity2", "entity2"); + when(entityInfoRepository.findAllById(any(Iterable.class))).thenReturn(Lists.newArrayList(entity1, entity2)); SpecService specService = new SpecService( entityInfoRepository, @@ -161,7 +181,26 @@ public void shouldGetFeaturesMatchingIds() { FeatureInfo feature2 = newTestFeatureInfo("feature2"); ArrayList ids = Lists.newArrayList("entity.none.feature1", "entity.none.feature2"); - when(featureInfoRepository.findAllById(ids)).thenReturn(Lists.newArrayList(feature1, feature2)); + when(featureInfoRepository.findAllById(any(Iterable.class))).thenReturn(Lists.newArrayList(feature1, feature2)); + SpecService specService = + new SpecService( + entityInfoRepository, + featureInfoRepository, + storageInfoRepository, + featureGroupInfoRepository, + schemaManager); + List actual = specService.getFeatures(ids); + List expected = Lists.newArrayList(feature1, feature2); + assertThat(actual, equalTo(expected)); + } + + @Test + public void shouldDeduplicateGetFeature() { + FeatureInfo feature1 = newTestFeatureInfo("feature1"); + FeatureInfo feature2 = newTestFeatureInfo("feature2"); + + ArrayList ids = Lists.newArrayList("entity.none.feature1", "entity.none.feature2", "entity.none.feature2"); + when(featureInfoRepository.findAllById(any(Iterable.class))).thenReturn(Lists.newArrayList(feature1, feature2)); SpecService specService = new SpecService( entityInfoRepository, @@ -216,7 +255,27 @@ public void shouldGetStorageMatchingIds() { StorageInfo bqStorage = newTestStorageInfo("BIGQUERY1", "BIGQUERY"); ArrayList ids = Lists.newArrayList("REDIS1", "BIGQUERY1"); - when(storageInfoRepository.findAllById(ids)) + when(storageInfoRepository.findAllById(any(Iterable.class))) + .thenReturn(Lists.newArrayList(redisStorage, bqStorage)); + SpecService specService = + new SpecService( + entityInfoRepository, + featureInfoRepository, + storageInfoRepository, + featureGroupInfoRepository, + schemaManager); + List actual = specService.getStorage(ids); + List expected = Lists.newArrayList(redisStorage, bqStorage); + assertThat(actual, equalTo(expected)); + } + + @Test + public void shouldDeduplicateGetStorage() { + StorageInfo redisStorage = newTestStorageInfo("REDIS1", "REDIS"); + StorageInfo bqStorage = newTestStorageInfo("BIGQUERY1", "BIGQUERY"); + + ArrayList ids = Lists.newArrayList("REDIS1", "BIGQUERY1", "BIGQUERY1"); + when(storageInfoRepository.findAllById(any(Iterable.class))) .thenReturn(Lists.newArrayList(redisStorage, bqStorage)); SpecService specService = new SpecService(