diff --git a/spring-cloud-gcp-data-datastore/src/main/java/com/google/cloud/spring/data/datastore/core/DatastoreTemplate.java b/spring-cloud-gcp-data-datastore/src/main/java/com/google/cloud/spring/data/datastore/core/DatastoreTemplate.java index fb951a5f22..67a5497c1e 100644 --- a/spring-cloud-gcp-data-datastore/src/main/java/com/google/cloud/spring/data/datastore/core/DatastoreTemplate.java +++ b/spring-cloud-gcp-data-datastore/src/main/java/com/google/cloud/spring/data/datastore/core/DatastoreTemplate.java @@ -16,6 +16,9 @@ package com.google.cloud.spring.data.datastore.core; +import com.google.cloud.datastore.AggregationQuery; +import com.google.cloud.datastore.AggregationResult; +import com.google.cloud.datastore.AggregationResults; import com.google.cloud.datastore.BaseEntity; import com.google.cloud.datastore.BaseKey; import com.google.cloud.datastore.Cursor; @@ -39,6 +42,7 @@ import com.google.cloud.datastore.StructuredQuery.Filter; import com.google.cloud.datastore.StructuredQuery.PropertyFilter; import com.google.cloud.datastore.Value; +import com.google.cloud.datastore.aggregation.Aggregation; import com.google.cloud.spring.data.datastore.core.convert.DatastoreEntityConverter; import com.google.cloud.spring.data.datastore.core.convert.ObjectToKeyFactory; import com.google.cloud.spring.data.datastore.core.mapping.DatastoreDataException; @@ -55,6 +59,7 @@ import com.google.cloud.spring.data.datastore.core.util.SliceUtil; import com.google.cloud.spring.data.datastore.core.util.ValueUtil; import com.google.cloud.spring.data.datastore.repository.query.DatastorePageable; +import com.google.common.collect.Iterables; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -247,7 +252,19 @@ private void performDelete(Key[] keys, Iterable ids, Iterable entities, Class en @Override public long count(Class entityClass) { - return findAllKeys(entityClass).length; + KeyQuery baseQuery = Query.newKeyQueryBuilder() + .setKind(getPersistentEntity(entityClass).kindName()) + .build(); + + AggregationQuery countAggregationQuery = Query.newAggregationQueryBuilder() + .over(baseQuery) + .addAggregation(Aggregation.count().as("total_count")) + .build(); + + AggregationResults aggregationResults = getDatastoreReadWriter().runAggregation(countAggregationQuery); + maybeEmitEvent(new AfterQueryEvent(aggregationResults, countAggregationQuery)); + AggregationResult aggregationResult = Iterables.getOnlyElement(aggregationResults); + return aggregationResult.get("total_count"); } @Override diff --git a/spring-cloud-gcp-data-datastore/src/test/java/com/google/cloud/spring/data/datastore/core/DatastoreTemplateTests.java b/spring-cloud-gcp-data-datastore/src/test/java/com/google/cloud/spring/data/datastore/core/DatastoreTemplateTests.java index 21be338b14..b2c8b40431 100644 --- a/spring-cloud-gcp-data-datastore/src/test/java/com/google/cloud/spring/data/datastore/core/DatastoreTemplateTests.java +++ b/spring-cloud-gcp-data-datastore/src/test/java/com/google/cloud/spring/data/datastore/core/DatastoreTemplateTests.java @@ -19,6 +19,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.notNull; import static org.mockito.ArgumentMatchers.same; @@ -28,6 +29,9 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.cloud.datastore.AggregationQuery; +import com.google.cloud.datastore.AggregationResult; +import com.google.cloud.datastore.AggregationResults; import com.google.cloud.datastore.Cursor; import com.google.cloud.datastore.Datastore; import com.google.cloud.datastore.Datastore.TransactionCallable; @@ -49,6 +53,7 @@ import com.google.cloud.datastore.QueryResults; import com.google.cloud.datastore.StructuredQuery; import com.google.cloud.datastore.StructuredQuery.PropertyFilter; +import com.google.cloud.datastore.aggregation.Aggregation; import com.google.cloud.spring.core.util.MapBuilder; import com.google.cloud.spring.data.datastore.core.convert.DatastoreEntityConverter; import com.google.cloud.spring.data.datastore.core.convert.ObjectToKeyFactory; @@ -85,6 +90,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.EnumSource; +import org.mockito.ArgumentMatcher; import org.mockito.ArgumentMatchers; import org.mockito.InOrder; import org.mockito.Mockito; @@ -1043,22 +1049,28 @@ private boolean nextPageTest(boolean hasNextPage) { @Test void countTest() { - QueryResults queryResults = mock(QueryResults.class); - when(queryResults.getResultClass()).thenReturn((Class) Key.class); - doAnswer( - invocation -> { - Arrays.asList(this.key1, this.key2) - .iterator() - .forEachRemaining(invocation.getArgument(0)); - return null; - }) - .when(queryResults) - .forEachRemaining(any()); - when(this.datastore.run(Query.newKeyQueryBuilder().setKind("custom_test_kind").build())) - .thenReturn(queryResults); + AggregationResult aggregationResult = mock(AggregationResult.class); + AggregationResults aggregationResults = mock(AggregationResults.class); + when(aggregationResult.get("total_count")).thenReturn(2L); + when(aggregationResults.iterator()).thenReturn(List.of(aggregationResult).iterator()); + + KeyQuery baseQuery = Query.newKeyQueryBuilder().setKind("custom_test_kind").build(); + AggregationQuery countAggregationQuery = Query.newAggregationQueryBuilder() + .over(baseQuery) + .addAggregation(Aggregation.count().as("total_count")) + .build(); + + when(this.datastore.runAggregation(argThat(equalsTo(countAggregationQuery)))) + .thenReturn(aggregationResults); assertThat(this.datastoreTemplate.count(TestEntity.class)).isEqualTo(2); } + private ArgumentMatcher equalsTo(AggregationQuery expectedAggregationQuery) { + return actualAggregationQuery -> + expectedAggregationQuery.getAggregations().equals(actualAggregationQuery.getAggregations()) + && expectedAggregationQuery.getNestedStructuredQuery().equals(actualAggregationQuery.getNestedStructuredQuery()); + } + @Test void existsByIdTest() { assertThat(this.datastoreTemplate.existsById(this.key1, TestEntity.class)).isTrue();