From 56d1001005cf7d52f8ba3e5258d6401be729bdbf Mon Sep 17 00:00:00 2001 From: Prateek Date: Tue, 22 Aug 2023 19:06:51 +0530 Subject: [PATCH] feat: Sum and Avg aggregation feature (#1067) * creating sum aggregation * creating avg aggregation * refactoring code to configure alias into an aggregation * Add method in aggregation query builders to accept aggregation in var args form * refactoring equals implementation * changed visibility from public to protected * Made aggregation result capable of returning double values * clirr ignore new method * Made transformer capable of parsing double values * mock webserver simulating sum and avg aggregation response * integration tests for sum and avg * Marking sum and avg methods with @BetaApi annotation * incorporating feedbacks * fixing lint * refactoring dispatcher code * incorporating feedbacks * cleaing up proxy code, as sum/avg now can be run against nightly * remove deprecated annotation * tests for sum and avg aggregations in GQL query * removing the addaggregation variant accepting list of aggregations * fix lint failures * removed BetaApi annotation from sum and avg aggregation * adding doc to AggregationResult#getDouble * adding sum/avg aggreggation test with autogenerated alias * adding a test to run sum and avg aggregation together * testing sum and avg aggregations with transactions * type check before returning the result from aggregation result class --- .../cloud/datastore/AggregationQuery.java | 13 + .../cloud/datastore/AggregationResult.java | 54 ++- .../datastore/aggregation/Aggregation.java | 20 + .../datastore/aggregation/AvgAggregation.java | 86 +++++ .../aggregation/CountAggregation.java | 17 +- .../datastore/aggregation/SumAggregation.java | 86 +++++ .../AggregationQueryResponseTransformer.java | 10 +- .../cloud/datastore/AggregationQueryTest.java | 85 ++++- .../datastore/AggregationResultTest.java | 51 ++- .../google/cloud/datastore/DatastoreTest.java | 2 +- .../google/cloud/datastore/ProtoTestData.java | 4 + .../aggregation/AvgAggregationTest.java | 88 +++++ .../aggregation/SumAggregationTest.java | 88 +++++ ...gregationQueryResponseTransformerTest.java | 52 ++- .../it/ITDatastoreAggregationsTest.java | 361 ++++++++++++++++++ .../cloud/datastore/it/ITDatastoreTest.java | 33 +- 16 files changed, 984 insertions(+), 66 deletions(-) create mode 100644 google-cloud-datastore/src/main/java/com/google/cloud/datastore/aggregation/AvgAggregation.java create mode 100644 google-cloud-datastore/src/main/java/com/google/cloud/datastore/aggregation/SumAggregation.java create mode 100644 google-cloud-datastore/src/test/java/com/google/cloud/datastore/aggregation/AvgAggregationTest.java create mode 100644 google-cloud-datastore/src/test/java/com/google/cloud/datastore/aggregation/SumAggregationTest.java create mode 100644 google-cloud-datastore/src/test/java/com/google/cloud/datastore/it/ITDatastoreAggregationsTest.java diff --git a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/AggregationQuery.java b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/AggregationQuery.java index 210889b1e..f657d4aae 100644 --- a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/AggregationQuery.java +++ b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/AggregationQuery.java @@ -21,6 +21,7 @@ import com.google.cloud.datastore.aggregation.Aggregation; import com.google.cloud.datastore.aggregation.AggregationBuilder; +import java.util.Arrays; import java.util.HashSet; import java.util.Set; @@ -143,6 +144,18 @@ public Builder addAggregation(Aggregation aggregation) { return this; } + public Builder addAggregations(AggregationBuilder... aggregationBuilders) { + for (AggregationBuilder builder : aggregationBuilders) { + this.aggregations.add(builder.build()); + } + return this; + } + + public Builder addAggregations(Aggregation... aggregations) { + this.aggregations.addAll(Arrays.asList(aggregations)); + return this; + } + public Builder over(StructuredQuery nestedQuery) { this.nestedStructuredQuery = nestedQuery; this.mode = STRUCTURED; diff --git a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/AggregationResult.java b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/AggregationResult.java index 6e086c30b..75636c004 100644 --- a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/AggregationResult.java +++ b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/AggregationResult.java @@ -15,6 +15,9 @@ */ package com.google.cloud.datastore; +import static com.google.cloud.datastore.ValueType.DOUBLE; +import static com.google.cloud.datastore.ValueType.LONG; + import com.google.common.base.MoreObjects; import com.google.common.base.MoreObjects.ToStringHelper; import java.util.Map; @@ -24,21 +27,62 @@ /** Represents a result of an {@link AggregationQuery} query submission. */ public class AggregationResult { - private final Map properties; + private final Map> properties; - public AggregationResult(Map properties) { + public AggregationResult(Map> properties) { this.properties = properties; } /** - * Returns a result value for the given alias. + * Returns a result value for the given alias. {@link #getLong(String)} is preferred over this + * method, Use {@link #getLong(String)} wherever possible. * * @param alias A custom alias provided in the query or an autogenerated alias in the form of * 'property_\d' * @return An aggregation result value for the given alias. */ public Long get(String alias) { - return properties.get(alias).get(); + return getLong(alias); + } + + /** + * Returns a result value for the given alias. + * + * @param alias A custom alias provided in the query or an autogenerated alias in the form of + * 'property_\d' + * @return An aggregation result value for the given alias. + */ + public Long getLong(String alias) { + Value value = properties.get(alias); + switch (value.getType()) { + case DOUBLE: + return ((Double) value.get()).longValue(); + case LONG: + return (Long) value.get(); + default: + throw new RuntimeException( + String.format("Unsupported type %s received for alias '%s'.", value.getType(), alias)); + } + } + + /** + * Returns a result value for the given alias. + * + * @param alias A custom alias provided in the query or an autogenerated alias in the form of + * 'property_\d' + * @return An aggregation result value for the given alias. + */ + public Double getDouble(String alias) { + Value value = properties.get(alias); + switch (value.getType()) { + case LONG: + return ((Long) value.get()).doubleValue(); + case DOUBLE: + return (Double) value.get(); + default: + throw new RuntimeException( + String.format("Unsupported type %s received for alias '%s'.", value.getType(), alias)); + } } @Override @@ -61,7 +105,7 @@ public int hashCode() { @Override public String toString() { ToStringHelper toStringHelper = MoreObjects.toStringHelper(this); - for (Entry entry : properties.entrySet()) { + for (Entry> entry : properties.entrySet()) { toStringHelper.add(entry.getKey(), entry.getValue().get()); } return toStringHelper.toString(); diff --git a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/aggregation/Aggregation.java b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/aggregation/Aggregation.java index 7bd2bbb38..4baf9e9dc 100644 --- a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/aggregation/Aggregation.java +++ b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/aggregation/Aggregation.java @@ -38,8 +38,28 @@ public String getAlias() { @InternalApi public abstract AggregationQuery.Aggregation toPb(); + @InternalApi + protected AggregationQuery.Aggregation.Builder aggregationBuilder() { + AggregationQuery.Aggregation.Builder aggregationBuilder = + AggregationQuery.Aggregation.newBuilder(); + if (this.getAlias() != null) { + aggregationBuilder.setAlias(this.getAlias()); + } + return aggregationBuilder; + } + /** Returns a {@link CountAggregation} builder. */ public static CountAggregation.Builder count() { return new CountAggregation.Builder(); } + + /** Returns a {@link SumAggregation} builder. */ + public static SumAggregation.Builder sum(String propertyReference) { + return new SumAggregation.Builder().propertyReference(propertyReference); + } + + /** Returns a {@link AvgAggregation} builder. */ + public static AvgAggregation.Builder avg(String propertyReference) { + return new AvgAggregation.Builder().propertyReference(propertyReference); + } } diff --git a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/aggregation/AvgAggregation.java b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/aggregation/AvgAggregation.java new file mode 100644 index 000000000..31bd28ffa --- /dev/null +++ b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/aggregation/AvgAggregation.java @@ -0,0 +1,86 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.datastore.aggregation; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.api.core.InternalApi; +import com.google.datastore.v1.AggregationQuery; +import com.google.datastore.v1.AggregationQuery.Aggregation.Avg; +import com.google.datastore.v1.PropertyReference; +import java.util.Objects; + +/** Represents an {@link Aggregation} which returns average of numerical values. */ +public class AvgAggregation extends Aggregation { + + private final String propertyReference; + + public AvgAggregation(String alias, String propertyReference) { + super(alias); + checkArgument(propertyReference != null, "Property reference can't be null"); + this.propertyReference = propertyReference; + } + + @InternalApi + @Override + public AggregationQuery.Aggregation toPb() { + PropertyReference reference = + PropertyReference.newBuilder().setName(this.propertyReference).build(); + Avg avg = Avg.newBuilder().setProperty(reference).build(); + return aggregationBuilder().setAvg(avg).build(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + AvgAggregation that = (AvgAggregation) o; + return Objects.equals(this.propertyReference, that.propertyReference) + && Objects.equals(getAlias(), that.getAlias()); + } + + @Override + public int hashCode() { + return Objects.hash(getAlias(), this.propertyReference); + } + + /** A builder class to create and customize a {@link AvgAggregation}. */ + public static class Builder implements AggregationBuilder { + + private String alias; + private String propertyReference; + + public AvgAggregation.Builder propertyReference(String propertyReference) { + this.propertyReference = propertyReference; + return this; + } + + public AvgAggregation.Builder as(String alias) { + this.alias = alias; + return this; + } + + @Override + public AvgAggregation build() { + return new AvgAggregation(alias, propertyReference); + } + } +} diff --git a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/aggregation/CountAggregation.java b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/aggregation/CountAggregation.java index 4f7eb23d6..632b6633d 100644 --- a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/aggregation/CountAggregation.java +++ b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/aggregation/CountAggregation.java @@ -30,14 +30,7 @@ public CountAggregation(String alias) { @Override public AggregationQuery.Aggregation toPb() { - Count.Builder countBuilder = Count.newBuilder(); - - AggregationQuery.Aggregation.Builder aggregationBuilder = - AggregationQuery.Aggregation.newBuilder().setCount(countBuilder); - if (this.getAlias() != null) { - aggregationBuilder.setAlias(this.getAlias()); - } - return aggregationBuilder.build(); + return aggregationBuilder().setCount(Count.newBuilder()).build(); } @Override @@ -49,13 +42,7 @@ public boolean equals(Object o) { return false; } CountAggregation that = (CountAggregation) o; - boolean bothAliasAreNull = getAlias() == null && that.getAlias() == null; - if (bothAliasAreNull) { - return true; - } else { - boolean bothArePresent = getAlias() != null && that.getAlias() != null; - return bothArePresent && getAlias().equals(that.getAlias()); - } + return Objects.equals(getAlias(), that.getAlias()); } @Override diff --git a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/aggregation/SumAggregation.java b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/aggregation/SumAggregation.java new file mode 100644 index 000000000..2e1dcd3d5 --- /dev/null +++ b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/aggregation/SumAggregation.java @@ -0,0 +1,86 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.datastore.aggregation; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.api.core.InternalApi; +import com.google.datastore.v1.AggregationQuery; +import com.google.datastore.v1.AggregationQuery.Aggregation.Sum; +import com.google.datastore.v1.PropertyReference; +import java.util.Objects; + +/** Represents an {@link Aggregation} which returns sum of numerical values. */ +public class SumAggregation extends Aggregation { + + private final String propertyReference; + + public SumAggregation(String alias, String propertyReference) { + super(alias); + checkArgument(propertyReference != null, "Property reference can't be null"); + this.propertyReference = propertyReference; + } + + @InternalApi + @Override + public AggregationQuery.Aggregation toPb() { + PropertyReference reference = + PropertyReference.newBuilder().setName(this.propertyReference).build(); + Sum sum = Sum.newBuilder().setProperty(reference).build(); + return aggregationBuilder().setSum(sum).build(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SumAggregation that = (SumAggregation) o; + return Objects.equals(this.propertyReference, that.propertyReference) + && Objects.equals(getAlias(), that.getAlias()); + } + + @Override + public int hashCode() { + return Objects.hash(getAlias(), this.propertyReference); + } + + /** A builder class to create and customize a {@link SumAggregation}. */ + public static class Builder implements AggregationBuilder { + + private String alias; + private String propertyReference; + + public SumAggregation.Builder propertyReference(String propertyReference) { + this.propertyReference = propertyReference; + return this; + } + + public SumAggregation.Builder as(String alias) { + this.alias = alias; + return this; + } + + @Override + public SumAggregation build() { + return new SumAggregation(alias, propertyReference); + } + } +} diff --git a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/execution/response/AggregationQueryResponseTransformer.java b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/execution/response/AggregationQueryResponseTransformer.java index 1515a1147..8c99fcd41 100644 --- a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/execution/response/AggregationQueryResponseTransformer.java +++ b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/execution/response/AggregationQueryResponseTransformer.java @@ -19,7 +19,6 @@ import com.google.cloud.Timestamp; import com.google.cloud.datastore.AggregationResult; import com.google.cloud.datastore.AggregationResults; -import com.google.cloud.datastore.LongValue; import com.google.datastore.v1.RunAggregationQueryResponse; import com.google.datastore.v1.Value; import java.util.AbstractMap.SimpleEntry; @@ -39,20 +38,19 @@ public AggregationResults transform(RunAggregationQueryResponse response) { Timestamp readTime = Timestamp.fromProto(response.getBatch().getReadTime()); List aggregationResults = response.getBatch().getAggregationResultsList().stream() - .map( - aggregationResult -> new AggregationResult(resultWithLongValues(aggregationResult))) + .map(aggregationResult -> new AggregationResult(transformValues(aggregationResult))) .collect(Collectors.toCollection(LinkedList::new)); return new AggregationResults(aggregationResults, readTime); } - private Map resultWithLongValues( + private Map> transformValues( com.google.datastore.v1.AggregationResult aggregationResult) { return aggregationResult.getAggregatePropertiesMap().entrySet().stream() .map( - (Function, Entry>) + (Function, Entry>>) entry -> new SimpleEntry<>( - entry.getKey(), (LongValue) LongValue.fromPb(entry.getValue()))) + entry.getKey(), com.google.cloud.datastore.Value.fromPb(entry.getValue()))) .collect(Collectors.toMap(Entry::getKey, Entry::getValue)); } } diff --git a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/AggregationQueryTest.java b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/AggregationQueryTest.java index 840d23bca..fd037808c 100644 --- a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/AggregationQueryTest.java +++ b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/AggregationQueryTest.java @@ -18,16 +18,18 @@ import static com.google.cloud.datastore.AggregationQuery.Mode.GQL; import static com.google.cloud.datastore.AggregationQuery.Mode.STRUCTURED; import static com.google.cloud.datastore.StructuredQuery.PropertyFilter.eq; +import static com.google.cloud.datastore.aggregation.Aggregation.avg; import static com.google.cloud.datastore.aggregation.Aggregation.count; +import static com.google.cloud.datastore.aggregation.Aggregation.sum; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; +import com.google.cloud.datastore.aggregation.AvgAggregation; import com.google.cloud.datastore.aggregation.CountAggregation; +import com.google.cloud.datastore.aggregation.SumAggregation; import com.google.common.collect.ImmutableSet; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; public class AggregationQueryTest { @@ -41,8 +43,6 @@ public class AggregationQueryTest { .setLimit(100) .build(); - @Rule public ExpectedException exceptionRule = ExpectedException.none(); - @Test public void testAggregations() { AggregationQuery aggregationQuery = @@ -60,37 +60,92 @@ public void testAggregations() { } @Test - public void testAggregationBuilderWithMoreThanOneAggregations() { + public void testAggregationBuilderWithMultipleAggregationsOneByOne() { AggregationQuery aggregationQuery = Query.newAggregationQueryBuilder() .setNamespace(NAMESPACE) .addAggregation(count().as("total")) - .addAggregation(count().as("new_total")) + .addAggregation(sum("marks").as("total_marks")) + .addAggregation(avg("marks").as("avg_marks")) .over(COMPLETED_TASK_QUERY) .build(); - assertThat(aggregationQuery.getNamespace()).isEqualTo(NAMESPACE); assertThat(aggregationQuery.getAggregations()) - .isEqualTo(ImmutableSet.of(count().as("total").build(), count().as("new_total").build())); - assertThat(aggregationQuery.getNestedStructuredQuery()).isEqualTo(COMPLETED_TASK_QUERY); - assertThat(aggregationQuery.getMode()).isEqualTo(STRUCTURED); + .isEqualTo( + ImmutableSet.of( + count().as("total").build(), + sum("marks").as("total_marks").build(), + avg("marks").as("avg_marks").build())); } @Test - public void testAggregationBuilderWithDuplicateAggregations() { + public void testAggregationBuilderWithMultipleAggregationsTogether() { AggregationQuery aggregationQuery = + Query.newAggregationQueryBuilder() + .setNamespace(NAMESPACE) + .addAggregations( + count().as("total"), sum("marks").as("total_marks"), avg("marks").as("avg_marks")) + .over(COMPLETED_TASK_QUERY) + .build(); + + assertThat(aggregationQuery.getAggregations()) + .isEqualTo( + ImmutableSet.of( + count().as("total").build(), + sum("marks").as("total_marks").build(), + avg("marks").as("avg_marks").build())); + } + + @Test + public void testAggregationBuilderWithMultipleAggregationsConfiguredThroughConstructor() { + AggregationQuery aggregationQuery = + Query.newAggregationQueryBuilder() + .setNamespace(NAMESPACE) + .addAggregations( + new CountAggregation("total"), + new SumAggregation("total_marks", "marks"), + new AvgAggregation("avg_marks", "marks")) + .over(COMPLETED_TASK_QUERY) + .build(); + + assertThat(aggregationQuery.getAggregations()) + .isEqualTo( + ImmutableSet.of( + count().as("total").build(), + sum("marks").as("total_marks").build(), + avg("marks").as("avg_marks").build())); + } + + @Test + public void testAggregationBuilderWithDuplicateAggregations() { + AggregationQuery aggregationQueryWithDuplicateCounts = Query.newAggregationQueryBuilder() .setNamespace(NAMESPACE) .addAggregation(count().as("total")) .addAggregation(count().as("total")) .over(COMPLETED_TASK_QUERY) .build(); + AggregationQuery aggregationQueryWithDuplicateSum = + Query.newAggregationQueryBuilder() + .setNamespace(NAMESPACE) + .addAggregation(sum("marks").as("total")) + .addAggregation(sum("marks").as("total")) + .over(COMPLETED_TASK_QUERY) + .build(); + AggregationQuery aggregationQueryWithDuplicateAvg = + Query.newAggregationQueryBuilder() + .setNamespace(NAMESPACE) + .addAggregation(avg("marks").as("avg_marks")) + .addAggregation(avg("marks").as("avg_marks")) + .over(COMPLETED_TASK_QUERY) + .build(); - assertThat(aggregationQuery.getNamespace()).isEqualTo(NAMESPACE); - assertThat(aggregationQuery.getAggregations()) + assertThat(aggregationQueryWithDuplicateCounts.getAggregations()) .isEqualTo(ImmutableSet.of(count().as("total").build())); - assertThat(aggregationQuery.getNestedStructuredQuery()).isEqualTo(COMPLETED_TASK_QUERY); - assertThat(aggregationQuery.getMode()).isEqualTo(STRUCTURED); + assertThat(aggregationQueryWithDuplicateSum.getAggregations()) + .isEqualTo(ImmutableSet.of(sum("marks").as("total").build())); + assertThat(aggregationQueryWithDuplicateAvg.getAggregations()) + .isEqualTo(ImmutableSet.of(avg("marks").as("avg_marks").build())); } @Test diff --git a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/AggregationResultTest.java b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/AggregationResultTest.java index 06a5cb5f7..592bfb368 100644 --- a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/AggregationResultTest.java +++ b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/AggregationResultTest.java @@ -16,6 +16,7 @@ package com.google.cloud.datastore; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import com.google.common.collect.ImmutableMap; import org.junit.Test; @@ -23,14 +24,58 @@ public class AggregationResultTest { @Test - public void shouldGetAggregationResultValueByAlias() { + public void shouldGetLongAggregatedResultValueByAlias() { AggregationResult aggregationResult = new AggregationResult( ImmutableMap.of( "count", LongValue.of(45), "property_2", LongValue.of(30))); - assertThat(aggregationResult.get("count")).isEqualTo(45L); - assertThat(aggregationResult.get("property_2")).isEqualTo(30L); + assertThat(aggregationResult.getLong("count")).isEqualTo(45L); + assertThat(aggregationResult.getLong("property_2")).isEqualTo(30L); + } + + @Test + public void shouldGetDoubleAggregatedResultValueByAlias() { + AggregationResult aggregationResult = + new AggregationResult( + ImmutableMap.of( + "qty_avg", DoubleValue.of(45.9322), + "qty_sum", DoubleValue.of(783.2134))); + + assertThat(aggregationResult.getDouble("qty_avg")).isEqualTo(45.9322); + assertThat(aggregationResult.getDouble("qty_sum")).isEqualTo(783.2134); + } + + @Test + public void shouldGetLongAggregatedResultValueAsDouble() { + AggregationResult aggregationResult = + new AggregationResult(ImmutableMap.of("count", LongValue.of(45))); + + assertThat(aggregationResult.getDouble("count")).isEqualTo(45D); + } + + @Test + public void shouldGetDoubleAggregatedResultValueAsLong() { + AggregationResult aggregationResult = + new AggregationResult(ImmutableMap.of("qty_avg", DoubleValue.of(45.9322))); + + assertThat(aggregationResult.getLong("qty_avg")).isEqualTo(45L); + } + + @Test + public void shouldThrowRuntimeExceptionOnUnknownTypes() { + AggregationResult aggregationResult = + new AggregationResult( + ImmutableMap.of( + "qty_avg", BooleanValue.of(true))); // only double and long types are supported + + RuntimeException e1 = + assertThrows(RuntimeException.class, () -> aggregationResult.getLong("qty_avg")); + assertThat(e1.getMessage()).isEqualTo("Unsupported type BOOLEAN received for alias 'qty_avg'."); + + RuntimeException e2 = + assertThrows(RuntimeException.class, () -> aggregationResult.getDouble("qty_avg")); + assertThat(e2.getMessage()).isEqualTo("Unsupported type BOOLEAN received for alias 'qty_avg'."); } } diff --git a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/DatastoreTest.java b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/DatastoreTest.java index d0f00d79b..cd768f986 100644 --- a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/DatastoreTest.java +++ b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/DatastoreTest.java @@ -555,7 +555,7 @@ public void testRunAggregationQuery() { .build(); AggregationResult result = getOnlyElement(mockDatastore.runAggregation(getCountQuery)); - assertThat(result.get("total_count")).isEqualTo(209L); + assertThat(result.getLong("total_count")).isEqualTo(209L); EasyMock.verify(rpcFactoryMock, rpcMock); } diff --git a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/ProtoTestData.java b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/ProtoTestData.java index 25b902fd4..8e2ba890a 100644 --- a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/ProtoTestData.java +++ b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/ProtoTestData.java @@ -42,6 +42,10 @@ public static Value intValue(long value) { return Value.newBuilder().setIntegerValue(value).build(); } + public static Value doubleValue(double value) { + return Value.newBuilder().setDoubleValue(value).build(); + } + public static GqlQueryParameter gqlQueryParameter(Value value) { return GqlQueryParameter.newBuilder().setValue(value).build(); } diff --git a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/aggregation/AvgAggregationTest.java b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/aggregation/AvgAggregationTest.java new file mode 100644 index 000000000..ac785c8b3 --- /dev/null +++ b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/aggregation/AvgAggregationTest.java @@ -0,0 +1,88 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.datastore.aggregation; + +import static com.google.cloud.datastore.aggregation.Aggregation.avg; +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.datastore.v1.AggregationQuery; +import org.junit.Test; + +public class AvgAggregationTest { + + @Test + public void shouldThrowExceptionWhenPropertyReferenceIsNull() { + assertThrows(IllegalArgumentException.class, () -> avg(null).build()); + } + + @Test + public void testAvgAggregationWithDefaultValues() { + AggregationQuery.Aggregation avgAggregation = avg("marks").build().toPb(); + + assertThat(avgAggregation.getAvg().getProperty().getName()).isEqualTo("marks"); + assertThat(avgAggregation.getAlias()).isEqualTo(""); + } + + @Test + public void testCountAggregationWithAlias() { + AggregationQuery.Aggregation avgAggregation = avg("marks").as("total_marks").build().toPb(); + + assertThat(avgAggregation.getAvg().getProperty().getName()).isEqualTo("marks"); + assertThat(avgAggregation.getAlias()).isEqualTo("total_marks"); + } + + @Test + public void testEqualsWithAliasVariations() { + AvgAggregation.Builder aggregationWithAlias1 = avg("marks").as("total"); + AvgAggregation.Builder aggregationWithAlias2 = avg("marks").as("total"); + AvgAggregation.Builder aggregationWithoutAlias1 = avg("marks"); + AvgAggregation.Builder aggregationWithoutAlias2 = avg("marks"); + + // same aliases + assertThat(aggregationWithAlias1.build()).isEqualTo(aggregationWithAlias2.build()); + assertThat(aggregationWithAlias2.build()).isEqualTo(aggregationWithAlias1.build()); + + // with and without aliases + assertThat(aggregationWithAlias1.build()).isNotEqualTo(aggregationWithoutAlias1.build()); + assertThat(aggregationWithoutAlias1.build()).isNotEqualTo(aggregationWithAlias1.build()); + + // no aliases + assertThat(aggregationWithoutAlias1.build()).isEqualTo(aggregationWithoutAlias2.build()); + assertThat(aggregationWithoutAlias2.build()).isEqualTo(aggregationWithoutAlias1.build()); + + // different aliases + assertThat(aggregationWithAlias1.as("new-alias").build()) + .isNotEqualTo(aggregationWithAlias2.build()); + assertThat(aggregationWithAlias2.build()) + .isNotEqualTo(aggregationWithAlias1.as("new-alias").build()); + } + + @Test + public void testEqualsWithPropertyReferenceVariations() { + AvgAggregation totalMarks1 = avg("marks").build(); + AvgAggregation totalMarks2 = avg("marks").build(); + + AvgAggregation totalQuantities = avg("quantity").build(); + + assertThat(totalMarks1).isEqualTo(totalMarks2); + assertThat(totalMarks2).isEqualTo(totalMarks1); + + assertThat(totalMarks1).isNotEqualTo(totalQuantities); + assertThat(totalQuantities).isNotEqualTo(totalMarks1); + } +} diff --git a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/aggregation/SumAggregationTest.java b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/aggregation/SumAggregationTest.java new file mode 100644 index 000000000..e4f637af0 --- /dev/null +++ b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/aggregation/SumAggregationTest.java @@ -0,0 +1,88 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.datastore.aggregation; + +import static com.google.cloud.datastore.aggregation.Aggregation.sum; +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.datastore.v1.AggregationQuery; +import org.junit.Test; + +public class SumAggregationTest { + + @Test + public void shouldThrowExceptionWhenPropertyReferenceIsNull() { + assertThrows(IllegalArgumentException.class, () -> sum(null).build()); + } + + @Test + public void testSumAggregationWithDefaultValues() { + AggregationQuery.Aggregation sumAggregation = sum("marks").build().toPb(); + + assertThat(sumAggregation.getSum().getProperty().getName()).isEqualTo("marks"); + assertThat(sumAggregation.getAlias()).isEqualTo(""); + } + + @Test + public void testCountAggregationWithAlias() { + AggregationQuery.Aggregation sumAggregation = sum("marks").as("total_marks").build().toPb(); + + assertThat(sumAggregation.getSum().getProperty().getName()).isEqualTo("marks"); + assertThat(sumAggregation.getAlias()).isEqualTo("total_marks"); + } + + @Test + public void testEqualsWithAliasVariations() { + SumAggregation.Builder aggregationWithAlias1 = sum("marks").as("total"); + SumAggregation.Builder aggregationWithAlias2 = sum("marks").as("total"); + SumAggregation.Builder aggregationWithoutAlias1 = sum("marks"); + SumAggregation.Builder aggregationWithoutAlias2 = sum("marks"); + + // same aliases + assertThat(aggregationWithAlias1.build()).isEqualTo(aggregationWithAlias2.build()); + assertThat(aggregationWithAlias2.build()).isEqualTo(aggregationWithAlias1.build()); + + // with and without aliases + assertThat(aggregationWithAlias1.build()).isNotEqualTo(aggregationWithoutAlias1.build()); + assertThat(aggregationWithoutAlias1.build()).isNotEqualTo(aggregationWithAlias1.build()); + + // no aliases + assertThat(aggregationWithoutAlias1.build()).isEqualTo(aggregationWithoutAlias2.build()); + assertThat(aggregationWithoutAlias2.build()).isEqualTo(aggregationWithoutAlias1.build()); + + // different aliases + assertThat(aggregationWithAlias1.as("new-alias").build()) + .isNotEqualTo(aggregationWithAlias2.build()); + assertThat(aggregationWithAlias2.build()) + .isNotEqualTo(aggregationWithAlias1.as("new-alias").build()); + } + + @Test + public void testEqualsWithPropertyReferenceVariations() { + SumAggregation totalMarks1 = sum("marks").build(); + SumAggregation totalMarks2 = sum("marks").build(); + + SumAggregation totalQuantities = sum("quantity").build(); + + assertThat(totalMarks1).isEqualTo(totalMarks2); + assertThat(totalMarks2).isEqualTo(totalMarks1); + + assertThat(totalMarks1).isNotEqualTo(totalQuantities); + assertThat(totalQuantities).isNotEqualTo(totalMarks1); + } +} diff --git a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/execution/response/AggregationQueryResponseTransformerTest.java b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/execution/response/AggregationQueryResponseTransformerTest.java index 8776d4221..7ba57223f 100644 --- a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/execution/response/AggregationQueryResponseTransformerTest.java +++ b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/execution/response/AggregationQueryResponseTransformerTest.java @@ -15,13 +15,13 @@ */ package com.google.cloud.datastore.execution.response; +import static com.google.cloud.datastore.ProtoTestData.doubleValue; import static com.google.cloud.datastore.ProtoTestData.intValue; import static com.google.common.truth.Truth.assertThat; import com.google.cloud.Timestamp; import com.google.cloud.datastore.AggregationResult; import com.google.cloud.datastore.AggregationResults; -import com.google.cloud.datastore.LongValue; import com.google.common.collect.ImmutableMap; import com.google.datastore.v1.AggregationResultBatch; import com.google.datastore.v1.RunAggregationQueryResponse; @@ -40,7 +40,7 @@ public class AggregationQueryResponseTransformerTest { new AggregationQueryResponseTransformer(); @Test - public void shouldTransformAggregationQueryResponse() { + public void shouldTransformAggregationQueryResponseWithIntValues() { Map result1 = new HashMap<>( ImmutableMap.of( @@ -51,7 +51,46 @@ public void shouldTransformAggregationQueryResponse() { new HashMap<>( ImmutableMap.of( "count", intValue(509), - "property_2", intValue(100))); + "property_2", intValue((100)))); + Timestamp readTime = Timestamp.now(); + + AggregationResultBatch resultBatch = + AggregationResultBatch.newBuilder() + .addAggregationResults( + com.google.datastore.v1.AggregationResult.newBuilder() + .putAllAggregateProperties(result1) + .build()) + .addAggregationResults( + com.google.datastore.v1.AggregationResult.newBuilder() + .putAllAggregateProperties(result2) + .build()) + .setReadTime(readTime.toProto()) + .build(); + RunAggregationQueryResponse runAggregationQueryResponse = + RunAggregationQueryResponse.newBuilder().setBatch(resultBatch).build(); + + AggregationResults aggregationResults = + responseTransformer.transform(runAggregationQueryResponse); + + assertThat(aggregationResults.size()).isEqualTo(2); + assertThat(aggregationResults.get(0)).isEqualTo(new AggregationResult(toDomainValues(result1))); + assertThat(aggregationResults.get(1)).isEqualTo(new AggregationResult(toDomainValues(result2))); + assertThat(aggregationResults.getReadTime()).isEqualTo(readTime); + } + + @Test + public void shouldTransformAggregationQueryResponseWithDoubleValues() { + Map result1 = + new HashMap<>( + ImmutableMap.of( + "count", doubleValue(209.678), + "property_2", doubleValue(100.678))); + + Map result2 = + new HashMap<>( + ImmutableMap.of( + "count", doubleValue(509.678), + "property_2", doubleValue((100.678)))); Timestamp readTime = Timestamp.now(); AggregationResultBatch resultBatch = @@ -78,14 +117,15 @@ public void shouldTransformAggregationQueryResponse() { assertThat(aggregationResults.getReadTime()).isEqualTo(readTime); } - private Map toDomainValues(Map map) { + private Map> toDomainValues( + Map map) { return map.entrySet().stream() .map( - (Function, Entry>) + (Function, Entry>>) entry -> new SimpleEntry<>( - entry.getKey(), (LongValue) LongValue.fromPb(entry.getValue()))) + entry.getKey(), com.google.cloud.datastore.Value.fromPb(entry.getValue()))) .collect(Collectors.toMap(Entry::getKey, Entry::getValue)); } } diff --git a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/it/ITDatastoreAggregationsTest.java b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/it/ITDatastoreAggregationsTest.java new file mode 100644 index 000000000..fd430095f --- /dev/null +++ b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/it/ITDatastoreAggregationsTest.java @@ -0,0 +1,361 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.datastore.it; + +import static com.google.cloud.datastore.aggregation.Aggregation.avg; +import static com.google.cloud.datastore.aggregation.Aggregation.count; +import static com.google.cloud.datastore.aggregation.Aggregation.sum; +import static com.google.common.collect.Iterables.getOnlyElement; +import static com.google.common.truth.Truth.assertThat; + +import com.google.cloud.datastore.AggregationQuery; +import com.google.cloud.datastore.AggregationResult; +import com.google.cloud.datastore.Datastore; +import com.google.cloud.datastore.Datastore.TransactionCallable; +import com.google.cloud.datastore.DatastoreOptions; +import com.google.cloud.datastore.Entity; +import com.google.cloud.datastore.EntityQuery; +import com.google.cloud.datastore.GqlQuery; +import com.google.cloud.datastore.Key; +import com.google.cloud.datastore.Query; +import com.google.cloud.datastore.QueryResults; +import com.google.cloud.datastore.Transaction; +import com.google.cloud.datastore.testing.RemoteDatastoreHelper; +import com.google.common.collect.ImmutableList; +import com.google.datastore.v1.TransactionOptions; +import com.google.datastore.v1.TransactionOptions.ReadOnly; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import org.junit.After; +import org.junit.Test; + +// TODO(jainsahab) Move all the aggregation related tests from ITDatastoreTest to this file +public class ITDatastoreAggregationsTest { + + private static final RemoteDatastoreHelper HELPER = RemoteDatastoreHelper.create(); + private static final DatastoreOptions OPTIONS = HELPER.getOptions(); + private static final Datastore DATASTORE = OPTIONS.getService(); + + private static final String KIND = "Marks"; + + @After + public void tearDown() { + EntityQuery allEntitiesQuery = Query.newEntityQueryBuilder().build(); + QueryResults allEntities = DATASTORE.run(allEntitiesQuery); + Key[] keysToDelete = + ImmutableList.copyOf(allEntities).stream().map(Entity::getKey).toArray(Key[]::new); + DATASTORE.delete(keysToDelete); + } + + Key key1 = DATASTORE.newKeyFactory().setKind(KIND).newKey(1); + Key key2 = DATASTORE.newKeyFactory().setKind(KIND).newKey(2); + Key key3 = DATASTORE.newKeyFactory().setKind(KIND).newKey(3); + + Entity entity1 = + Entity.newBuilder(key1).set("name", "Jon Stark").set("marks", 89).set("cgpa", 7.34).build(); + Entity entity2 = + Entity.newBuilder(key2).set("name", "Arya Stark").set("marks", 95).set("cgpa", 9.27).build(); + Entity entity3 = + Entity.newBuilder(key3).set("name", "Night king").set("marks", 55).set("cgpa", 5.16).build(); + + @Test + public void testSumAggregation() { + DATASTORE.put(entity1, entity2); + + EntityQuery baseQuery = Query.newEntityQueryBuilder().setKind(KIND).build(); + AggregationQuery aggregationQuery = + Query.newAggregationQueryBuilder() + .over(baseQuery) + .addAggregations(sum("marks").as("total_marks")) + .setNamespace(OPTIONS.getNamespace()) + .build(); + + // sum of 2 entities + assertThat(getOnlyElement(DATASTORE.runAggregation(aggregationQuery)).getLong("total_marks")) + .isEqualTo(184L); + + // sum of 3 entities + DATASTORE.put(entity3); + assertThat(getOnlyElement(DATASTORE.runAggregation(aggregationQuery)).getLong("total_marks")) + .isEqualTo(239L); + } + + @Test + public void testSumAggregationWithAutoGeneratedAlias() { + DATASTORE.put(entity1, entity2); + + EntityQuery baseQuery = Query.newEntityQueryBuilder().setKind(KIND).build(); + AggregationQuery aggregationQuery = + Query.newAggregationQueryBuilder() + .over(baseQuery) + .addAggregations(sum("marks")) + .setNamespace(OPTIONS.getNamespace()) + .build(); + + // sum of 2 entities + assertThat(getOnlyElement(DATASTORE.runAggregation(aggregationQuery)).getLong("property_1")) + .isEqualTo(184L); + + // sum of 3 entities + DATASTORE.put(entity3); + assertThat(getOnlyElement(DATASTORE.runAggregation(aggregationQuery)).getLong("property_1")) + .isEqualTo(239L); + } + + @Test + public void testSumAggregationInGqlQuery() { + DATASTORE.put(entity1, entity2); + + GqlQuery gqlQuery = + GqlQuery.newGqlQueryBuilder( + "AGGREGATE SUM(marks) AS total_marks OVER (SELECT * FROM Marks)") + .build(); + + AggregationQuery aggregationQuery = + Query.newAggregationQueryBuilder() + .over(gqlQuery) + .setNamespace(OPTIONS.getNamespace()) + .build(); + + // sum of 2 entities + assertThat(getOnlyElement(DATASTORE.runAggregation(aggregationQuery)).getLong("total_marks")) + .isEqualTo(184L); + + // sum of 3 entities + DATASTORE.put(entity3); + assertThat(getOnlyElement(DATASTORE.runAggregation(aggregationQuery)).getLong("total_marks")) + .isEqualTo(239L); + } + + @Test + public void testSumAggregationWithResultOfDoubleType() { + DATASTORE.put(entity1, entity2); + + EntityQuery baseQuery = Query.newEntityQueryBuilder().setKind(KIND).build(); + AggregationQuery aggregationQuery = + Query.newAggregationQueryBuilder() + .over(baseQuery) + .addAggregations(sum("cgpa").as("total_cgpa")) + .setNamespace(OPTIONS.getNamespace()) + .build(); + + // sum of 2 entities + assertThat(getOnlyElement(DATASTORE.runAggregation(aggregationQuery)).getDouble("total_cgpa")) + .isEqualTo(16.61); + + // sum of 3 entities + DATASTORE.put(entity3); + assertThat(getOnlyElement(DATASTORE.runAggregation(aggregationQuery)).getDouble("total_cgpa")) + .isEqualTo(21.77); + } + + @Test + public void testAvgAggregation() { + DATASTORE.put(entity1, entity2); + + EntityQuery baseQuery = Query.newEntityQueryBuilder().setKind(KIND).build(); + AggregationQuery aggregationQuery = + Query.newAggregationQueryBuilder() + .over(baseQuery) + .addAggregations(avg("marks").as("avg_marks")) + .setNamespace(OPTIONS.getNamespace()) + .build(); + + // avg of 2 entities + assertThat(getOnlyElement(DATASTORE.runAggregation(aggregationQuery)).getDouble("avg_marks")) + .isEqualTo(92D); + + // avg of 3 entities + DATASTORE.put(entity3); + assertThat(getOnlyElement(DATASTORE.runAggregation(aggregationQuery)).getDouble("avg_marks")) + .isEqualTo(79.66666666666667); + } + + @Test + public void testAvgAggregationWithAutoGeneratedAlias() { + DATASTORE.put(entity1, entity2); + + EntityQuery baseQuery = Query.newEntityQueryBuilder().setKind(KIND).build(); + AggregationQuery aggregationQuery = + Query.newAggregationQueryBuilder() + .over(baseQuery) + .addAggregations(avg("marks")) + .setNamespace(OPTIONS.getNamespace()) + .build(); + + // avg of 2 entities + assertThat(getOnlyElement(DATASTORE.runAggregation(aggregationQuery)).getDouble("property_1")) + .isEqualTo(92D); + + // avg of 3 entities + DATASTORE.put(entity3); + assertThat(getOnlyElement(DATASTORE.runAggregation(aggregationQuery)).getDouble("property_1")) + .isEqualTo(79.66666666666667); + } + + @Test + public void testAvgAggregationInGqlQuery() { + DATASTORE.put(entity1, entity2); + + GqlQuery gqlQuery = + Query.newGqlQueryBuilder("AGGREGATE AVG(marks) AS avg_marks OVER (SELECT * FROM Marks)") + .build(); + + AggregationQuery aggregationQuery = + Query.newAggregationQueryBuilder() + .over(gqlQuery) + .setNamespace(OPTIONS.getNamespace()) + .build(); + + // avg of 2 entities + assertThat(getOnlyElement(DATASTORE.runAggregation(aggregationQuery)).getDouble("avg_marks")) + .isEqualTo(92D); + + // avg of 3 entities + DATASTORE.put(entity3); + assertThat(getOnlyElement(DATASTORE.runAggregation(aggregationQuery)).getDouble("avg_marks")) + .isEqualTo(79.66666666666667); + } + + @Test + public void testSumAndAvgAggregationTogether() { + DATASTORE.put(entity1, entity2); + + EntityQuery baseQuery = Query.newEntityQueryBuilder().setKind(KIND).build(); + AggregationQuery aggregationQuery = + Query.newAggregationQueryBuilder() + .over(baseQuery) + .addAggregations(sum("marks").as("total_marks")) + .addAggregations(avg("marks").as("avg_marks")) + .setNamespace(OPTIONS.getNamespace()) + .build(); + + // sum of 2 entities + assertThat(getOnlyElement(DATASTORE.runAggregation(aggregationQuery)).getLong("total_marks")) + .isEqualTo(184L); + // avg of 2 entities + assertThat(getOnlyElement(DATASTORE.runAggregation(aggregationQuery)).getDouble("avg_marks")) + .isEqualTo(92D); + } + + @Test + public void testTransactionShouldReturnAConsistentSnapshot() { + DATASTORE.put(entity1, entity2); + + EntityQuery baseQuery = Query.newEntityQueryBuilder().setKind(KIND).build(); + AggregationQuery aggregationQuery = + Query.newAggregationQueryBuilder() + .over(baseQuery) + .addAggregation(count().as("count")) + .addAggregations(sum("marks").as("total_marks")) + .addAggregations(avg("marks").as("avg_marks")) + .setNamespace(OPTIONS.getNamespace()) + .build(); + + // original entity count is 2 + assertThat(getOnlyElement(DATASTORE.runAggregation(aggregationQuery)).getLong("count")) + .isEqualTo(2L); + + // FIRST TRANSACTION + DATASTORE.runInTransaction( + (TransactionCallable) + inFirstTransaction -> { + // creating a new entity + inFirstTransaction.put(entity3); + + // aggregation result consistently being produced for original 2 entities + AggregationResult aggregationResult = + getOnlyElement(inFirstTransaction.runAggregation(aggregationQuery)); + assertThat(aggregationResult.getLong("count")).isEqualTo(2L); + assertThat(aggregationResult.getLong("total_marks")).isEqualTo(184L); + assertThat(aggregationResult.getDouble("avg_marks")).isEqualTo(92D); + return null; + }); + + // after first transaction is committed, we have 3 entities now. + assertThat(getOnlyElement(DATASTORE.runAggregation(aggregationQuery)).getLong("count")) + .isEqualTo(3L); + + // SECOND TRANSACTION + DATASTORE.runInTransaction( + (TransactionCallable) + inSecondTransaction -> { + // deleting ENTITY3 + inSecondTransaction.delete(entity3.getKey()); + + // aggregation result still coming for 3 entities + AggregationResult aggregationResult = + getOnlyElement(inSecondTransaction.runAggregation(aggregationQuery)); + assertThat(aggregationResult.getLong("count")).isEqualTo(3L); + assertThat(aggregationResult.getLong("total_marks")).isEqualTo(239L); + assertThat(aggregationResult.getDouble("avg_marks")).isEqualTo(79.66666666666667); + return null; + }); + + // after second transaction is committed, we are back to 2 entities now. + assertThat(getOnlyElement(DATASTORE.runAggregation(aggregationQuery)).getLong("count")) + .isEqualTo(2L); + } + + @Test + public void testReadOnlyTransactionShouldNotLockTheDocuments() + throws ExecutionException, InterruptedException { + ExecutorService executor = Executors.newSingleThreadExecutor(); + DATASTORE.put(entity1, entity2); + + EntityQuery baseQuery = Query.newEntityQueryBuilder().setKind(KIND).build(); + AggregationQuery aggregationQuery = + Query.newAggregationQueryBuilder() + .over(baseQuery) + .addAggregation(count().as("count")) + .addAggregations(sum("marks").as("total_marks")) + .addAggregations(avg("marks").as("avg_marks")) + .setNamespace(OPTIONS.getNamespace()) + .build(); + + TransactionOptions transactionOptions = + TransactionOptions.newBuilder().setReadOnly(ReadOnly.newBuilder().build()).build(); + Transaction readOnlyTransaction = DATASTORE.newTransaction(transactionOptions); + + // Executing query in transaction, results for original 2 entities + AggregationResult aggregationResult = + getOnlyElement(readOnlyTransaction.runAggregation(aggregationQuery)); + assertThat(aggregationResult.getLong("count")).isEqualTo(2L); + assertThat(aggregationResult.getLong("total_marks")).isEqualTo(184L); + assertThat(aggregationResult.getDouble("avg_marks")).isEqualTo(92D); + + // Concurrent write task. + Future addNewEntityTaskOutsideTransaction = + executor.submit( + () -> { + DATASTORE.put(entity3); + return null; + }); + + // should not throw exception and complete successfully as the ongoing transaction is read-only. + addNewEntityTaskOutsideTransaction.get(); + + // cleanup + readOnlyTransaction.commit(); + executor.shutdownNow(); + + assertThat(getOnlyElement(DATASTORE.runAggregation(aggregationQuery)).getLong("count")) + .isEqualTo(3L); + } +} diff --git a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/it/ITDatastoreTest.java b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/it/ITDatastoreTest.java index 167d2898c..7c68ffe32 100644 --- a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/it/ITDatastoreTest.java +++ b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/it/ITDatastoreTest.java @@ -821,7 +821,7 @@ public void testRunAggregationQueryInTransactionShouldReturnAConsistentSnapshot( .build(); // original entity count is 2 - assertThat(getOnlyElement(datastore.runAggregation(aggregationQuery)).get("count")) + assertThat(getOnlyElement(datastore.runAggregation(aggregationQuery)).getLong("count")) .isEqualTo(2L); // FIRST TRANSACTION @@ -836,14 +836,15 @@ public void testRunAggregationQueryInTransactionShouldReturnAConsistentSnapshot( // count remains 2 assertThat( getOnlyElement(inFirstTransaction.runAggregation(aggregationQuery)) - .get("count")) + .getLong("count")) .isEqualTo(2L); - assertThat(getOnlyElement(datastore.runAggregation(aggregationQuery)).get("count")) + assertThat( + getOnlyElement(datastore.runAggregation(aggregationQuery)).getLong("count")) .isEqualTo(2L); return null; }); // after first transaction is committed, count is updated to 3 now. - assertThat(getOnlyElement(datastore.runAggregation(aggregationQuery)).get("count")) + assertThat(getOnlyElement(datastore.runAggregation(aggregationQuery)).getLong("count")) .isEqualTo(3L); // SECOND TRANSACTION @@ -856,14 +857,15 @@ public void testRunAggregationQueryInTransactionShouldReturnAConsistentSnapshot( // count remains 3 assertThat( getOnlyElement(inSecondTransaction.runAggregation(aggregationQuery)) - .get("count")) + .getLong("count")) .isEqualTo(3L); - assertThat(getOnlyElement(datastore.runAggregation(aggregationQuery)).get("count")) + assertThat( + getOnlyElement(datastore.runAggregation(aggregationQuery)).getLong("count")) .isEqualTo(3L); return null; }); // after second transaction is committed, count is updated to 2 now. - assertThat(getOnlyElement(datastore.runAggregation(aggregationQuery)).get("count")) + assertThat(getOnlyElement(datastore.runAggregation(aggregationQuery)).getLong("count")) .isEqualTo(2L); datastore.delete(newEntityKey); } @@ -889,7 +891,8 @@ public void testRunAggregationQueryInAReadOnlyTransactionShouldNotLockTheCounted Transaction readOnlyTransaction = datastore.newTransaction(transactionOptions); // Executing query in transaction - assertThat(getOnlyElement(readOnlyTransaction.runAggregation(aggregationQuery)).get("count")) + assertThat( + getOnlyElement(readOnlyTransaction.runAggregation(aggregationQuery)).getLong("count")) .isEqualTo(2L); // Concurrent write task. @@ -912,7 +915,7 @@ public void testRunAggregationQueryInAReadOnlyTransactionShouldNotLockTheCounted readOnlyTransaction.commit(); executor.shutdownNow(); - assertThat(getOnlyElement(datastore.runAggregation(aggregationQuery)).get("count")) + assertThat(getOnlyElement(datastore.runAggregation(aggregationQuery)).getLong("count")) .isEqualTo(3L); } @@ -1561,7 +1564,7 @@ private void testCountAggregationWith(Consumer configu AggregationQuery aggregationQuery = builder.build(); String alias = "total_count"; - Long countBeforeAdd = getOnlyElement(datastore.runAggregation(aggregationQuery)).get(alias); + Long countBeforeAdd = getOnlyElement(datastore.runAggregation(aggregationQuery)).getLong(alias); long expectedCount = countBeforeAdd + 1; Entity newEntity = @@ -1573,7 +1576,7 @@ private void testCountAggregationWith(Consumer configu .build(); datastore.put(newEntity); - Long countAfterAdd = getOnlyElement(datastore.runAggregation(aggregationQuery)).get(alias); + Long countAfterAdd = getOnlyElement(datastore.runAggregation(aggregationQuery)).getLong(alias); assertThat(countAfterAdd).isEqualTo(expectedCount); datastore.delete(newEntity.getKey()); @@ -1589,7 +1592,7 @@ private void testCountAggregationWithLimit( withoutLimitConfigurer.accept(withoutLimitBuilder); Long currentCount = - getOnlyElement(datastore.runAggregation(withoutLimitBuilder.build())).get(alias); + getOnlyElement(datastore.runAggregation(withoutLimitBuilder.build())).getLong(alias); long limit = currentCount - 1; AggregationQuery.Builder withLimitBuilder = @@ -1597,7 +1600,7 @@ private void testCountAggregationWithLimit( withLimitConfigurer.accept(withLimitBuilder, limit); Long countWithLimit = - getOnlyElement(datastore.runAggregation(withLimitBuilder.build())).get(alias); + getOnlyElement(datastore.runAggregation(withLimitBuilder.build())).getLong(alias); assertThat(countWithLimit).isEqualTo(limit); } @@ -1640,12 +1643,12 @@ private void testCountAggregationReadTimeWith(Consumer AggregationQuery countAggregationQuery = builder.build(); Long latestCount = - getOnlyElement(datastore.runAggregation(countAggregationQuery)).get("total_count"); + getOnlyElement(datastore.runAggregation(countAggregationQuery)).getLong("total_count"); assertThat(latestCount).isEqualTo(3L); Long oldCount = getOnlyElement(datastore.runAggregation(countAggregationQuery, ReadOption.readTime(now))) - .get("total_count"); + .getLong("total_count"); assertThat(oldCount).isEqualTo(2L); } finally { datastore.delete(entity1.getKey(), entity2.getKey(), entity3.getKey());