Skip to content

Commit

Permalink
fix: Implementing count with aggregation query (GoogleCloudPlatform#1782
Browse files Browse the repository at this point in the history
)

Modifying the implementation of DatastoreTemplate#count to use recently introduced [COUNT aggregation and Aggregation queries in datastore](https://cloud.google.com/datastore/docs/aggregation-queries).

Fixes. GoogleCloudPlatform#1781
  • Loading branch information
its-snorlax authored May 8, 2023
1 parent b134c92 commit 1ee2244
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

package com.google.cloud.spring.data.datastore.core;

import com.google.cloud.datastore.AggregationQuery;
import com.google.cloud.datastore.AggregationResult;
import com.google.cloud.datastore.AggregationResults;
import com.google.cloud.datastore.BaseEntity;
import com.google.cloud.datastore.BaseKey;
import com.google.cloud.datastore.Cursor;
Expand All @@ -39,6 +42,7 @@
import com.google.cloud.datastore.StructuredQuery.Filter;
import com.google.cloud.datastore.StructuredQuery.PropertyFilter;
import com.google.cloud.datastore.Value;
import com.google.cloud.datastore.aggregation.Aggregation;
import com.google.cloud.spring.data.datastore.core.convert.DatastoreEntityConverter;
import com.google.cloud.spring.data.datastore.core.convert.ObjectToKeyFactory;
import com.google.cloud.spring.data.datastore.core.mapping.DatastoreDataException;
Expand All @@ -55,6 +59,7 @@
import com.google.cloud.spring.data.datastore.core.util.SliceUtil;
import com.google.cloud.spring.data.datastore.core.util.ValueUtil;
import com.google.cloud.spring.data.datastore.repository.query.DatastorePageable;
import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
Expand Down Expand Up @@ -247,7 +252,19 @@ private void performDelete(Key[] keys, Iterable ids, Iterable entities, Class en

@Override
public long count(Class<?> entityClass) {
return findAllKeys(entityClass).length;
KeyQuery baseQuery = Query.newKeyQueryBuilder()
.setKind(getPersistentEntity(entityClass).kindName())
.build();

AggregationQuery countAggregationQuery = Query.newAggregationQueryBuilder()
.over(baseQuery)
.addAggregation(Aggregation.count().as("total_count"))
.build();

AggregationResults aggregationResults = getDatastoreReadWriter().runAggregation(countAggregationQuery);
maybeEmitEvent(new AfterQueryEvent(aggregationResults, countAggregationQuery));
AggregationResult aggregationResult = Iterables.getOnlyElement(aggregationResults);
return aggregationResult.get("total_count");
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.notNull;
import static org.mockito.ArgumentMatchers.same;
Expand All @@ -28,6 +29,9 @@
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.google.cloud.datastore.AggregationQuery;
import com.google.cloud.datastore.AggregationResult;
import com.google.cloud.datastore.AggregationResults;
import com.google.cloud.datastore.Cursor;
import com.google.cloud.datastore.Datastore;
import com.google.cloud.datastore.Datastore.TransactionCallable;
Expand All @@ -49,6 +53,7 @@
import com.google.cloud.datastore.QueryResults;
import com.google.cloud.datastore.StructuredQuery;
import com.google.cloud.datastore.StructuredQuery.PropertyFilter;
import com.google.cloud.datastore.aggregation.Aggregation;
import com.google.cloud.spring.core.util.MapBuilder;
import com.google.cloud.spring.data.datastore.core.convert.DatastoreEntityConverter;
import com.google.cloud.spring.data.datastore.core.convert.ObjectToKeyFactory;
Expand Down Expand Up @@ -85,6 +90,7 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.mockito.ArgumentMatcher;
import org.mockito.ArgumentMatchers;
import org.mockito.InOrder;
import org.mockito.Mockito;
Expand Down Expand Up @@ -1043,22 +1049,28 @@ private boolean nextPageTest(boolean hasNextPage) {

@Test
void countTest() {
QueryResults<Key> queryResults = mock(QueryResults.class);
when(queryResults.getResultClass()).thenReturn((Class) Key.class);
doAnswer(
invocation -> {
Arrays.asList(this.key1, this.key2)
.iterator()
.forEachRemaining(invocation.getArgument(0));
return null;
})
.when(queryResults)
.forEachRemaining(any());
when(this.datastore.run(Query.newKeyQueryBuilder().setKind("custom_test_kind").build()))
.thenReturn(queryResults);
AggregationResult aggregationResult = mock(AggregationResult.class);
AggregationResults aggregationResults = mock(AggregationResults.class);
when(aggregationResult.get("total_count")).thenReturn(2L);
when(aggregationResults.iterator()).thenReturn(List.of(aggregationResult).iterator());

KeyQuery baseQuery = Query.newKeyQueryBuilder().setKind("custom_test_kind").build();
AggregationQuery countAggregationQuery = Query.newAggregationQueryBuilder()
.over(baseQuery)
.addAggregation(Aggregation.count().as("total_count"))
.build();

when(this.datastore.runAggregation(argThat(equalsTo(countAggregationQuery))))
.thenReturn(aggregationResults);
assertThat(this.datastoreTemplate.count(TestEntity.class)).isEqualTo(2);
}

private ArgumentMatcher<AggregationQuery> equalsTo(AggregationQuery expectedAggregationQuery) {
return actualAggregationQuery ->
expectedAggregationQuery.getAggregations().equals(actualAggregationQuery.getAggregations())
&& expectedAggregationQuery.getNestedStructuredQuery().equals(actualAggregationQuery.getNestedStructuredQuery());
}

@Test
void existsByIdTest() {
assertThat(this.datastoreTemplate.existsById(this.key1, TestEntity.class)).isTrue();
Expand Down

0 comments on commit 1ee2244

Please sign in to comment.