From 6fc57e47c86566f600c56e96d0f2e1721baa6af6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Fri, 9 Feb 2024 16:01:09 +0100 Subject: [PATCH] feat: support lazy decoding of query results (#2847) * feat: support lazy decoding of query results Adds an option for lazy decoding of query results. Currently, all values in a query result row are decoded from protobuf values to plain Java objects at the moment that the result set is advanced to the next row. This means that all values are decoded, regardless whether the application actually fetches these or not. Lazy decoding also enables the possibility for (internal) consumers of a result set to access the protobuf value before it is converted to a plain Java object. This for example allows ChecksumResultSet to calculate the checksum based on the protobuf value, instead of a Java object, which can be more efficient. * fix: add null check * perf: calculate checksum using protobuf values (#2848) * perf: calculate checksum using protobuf values * chore: cleanup * test: remove unrelated test * fix: undo change to public API * chore: cleanup| --- .../cloud/spanner/AbstractReadContext.java | 15 +- .../google/cloud/spanner/BatchClientImpl.java | 2 + .../com/google/cloud/spanner/DecodeMode.java | 35 ++ .../cloud/spanner/ForwardingResultSet.java | 18 +- .../google/cloud/spanner/GrpcResultSet.java | 24 +- .../com/google/cloud/spanner/GrpcStruct.java | 132 +++++- .../com/google/cloud/spanner/Options.java | 32 ++ .../cloud/spanner/ProtobufResultSet.java | 42 ++ .../com/google/cloud/spanner/ResultSets.java | 15 +- .../com/google/cloud/spanner/SessionImpl.java | 4 + .../com/google/cloud/spanner/SpannerImpl.java | 4 + .../google/cloud/spanner/SpannerOptions.java | 20 + .../java/com/google/cloud/spanner/Value.java | 4 + .../spanner/connection/ChecksumResultSet.java | 375 +++++++-------- .../connection/DirectExecuteResultSet.java | 18 +- .../connection/ReadWriteTransaction.java | 7 +- .../ReplaceableForwardingResultSet.java | 22 +- .../cloud/spanner/connection/SpannerPool.java | 4 + .../spanner/connection/DecodeModeTest.java | 128 ++++++ .../DirectExecuteResultSetTest.java | 3 + .../connection/RandomResultSetGenerator.java | 6 +- .../connection/ReadWriteTransactionTest.java | 435 +++++++++--------- .../ReplaceableForwardingResultSetTest.java | 11 +- 23 files changed, 910 insertions(+), 446 deletions(-) create mode 100644 google-cloud-spanner/src/main/java/com/google/cloud/spanner/DecodeMode.java create mode 100644 google-cloud-spanner/src/main/java/com/google/cloud/spanner/ProtobufResultSet.java create mode 100644 google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/DecodeModeTest.java diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java index 2505b8b4985..ae18a1e4f5f 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java @@ -70,6 +70,7 @@ abstract static class Builder, T extends AbstractReadCon private TraceWrapper tracer; private int defaultPrefetchChunks = SpannerOptions.Builder.DEFAULT_PREFETCH_CHUNKS; private QueryOptions defaultQueryOptions = SpannerOptions.Builder.DEFAULT_QUERY_OPTIONS; + private DecodeMode defaultDecodeMode = SpannerOptions.Builder.DEFAULT_DECODE_MODE; private DirectedReadOptions defaultDirectedReadOption; private ExecutorProvider executorProvider; private Clock clock = new Clock(); @@ -111,6 +112,11 @@ B setDefaultQueryOptions(QueryOptions defaultQueryOptions) { return self(); } + B setDefaultDecodeMode(DecodeMode defaultDecodeMode) { + this.defaultDecodeMode = defaultDecodeMode; + return self(); + } + B setExecutorProvider(ExecutorProvider executorProvider) { this.executorProvider = executorProvider; return self(); @@ -411,8 +417,8 @@ void initTransaction() { TraceWrapper tracer; private final int defaultPrefetchChunks; private final QueryOptions defaultQueryOptions; - private final DirectedReadOptions defaultDirectedReadOptions; + private final DecodeMode defaultDecodeMode; private final Clock clock; @GuardedBy("lock") @@ -438,6 +444,7 @@ void initTransaction() { this.defaultPrefetchChunks = builder.defaultPrefetchChunks; this.defaultQueryOptions = builder.defaultQueryOptions; this.defaultDirectedReadOptions = builder.defaultDirectedReadOption; + this.defaultDecodeMode = builder.defaultDecodeMode; this.span = builder.span; this.executorProvider = builder.executorProvider; this.clock = builder.clock; @@ -727,7 +734,8 @@ CloseableIterator startStream(@Nullable ByteString resumeToken return stream; } }; - return new GrpcResultSet(stream, this); + return new GrpcResultSet( + stream, this, options.hasDecodeMode() ? options.decodeMode() : defaultDecodeMode); } /** @@ -871,7 +879,8 @@ CloseableIterator startStream(@Nullable ByteString resumeToken return stream; } }; - return new GrpcResultSet(stream, this); + return new GrpcResultSet( + stream, this, readOptions.hasDecodeMode() ? readOptions.decodeMode() : defaultDecodeMode); } private Struct consumeSingleRow(ResultSet resultSet) { diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java index 664cde1edbb..22fb9f710c1 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java @@ -60,6 +60,7 @@ public BatchReadOnlyTransaction batchReadOnlyTransaction(TimestampBound bound) { sessionClient.getSpanner().getDefaultQueryOptions(sessionClient.getDatabaseId())) .setExecutorProvider(sessionClient.getSpanner().getAsyncExecutorProvider()) .setDefaultPrefetchChunks(sessionClient.getSpanner().getDefaultPrefetchChunks()) + .setDefaultDecodeMode(sessionClient.getSpanner().getDefaultDecodeMode()) .setDefaultDirectedReadOptions( sessionClient.getSpanner().getOptions().getDirectedReadOptions()) .setSpan(sessionClient.getSpanner().getTracer().getCurrentSpan()) @@ -81,6 +82,7 @@ public BatchReadOnlyTransaction batchReadOnlyTransaction(BatchTransactionId batc sessionClient.getSpanner().getDefaultQueryOptions(sessionClient.getDatabaseId())) .setExecutorProvider(sessionClient.getSpanner().getAsyncExecutorProvider()) .setDefaultPrefetchChunks(sessionClient.getSpanner().getDefaultPrefetchChunks()) + .setDefaultDecodeMode(sessionClient.getSpanner().getDefaultDecodeMode()) .setDefaultDirectedReadOptions( sessionClient.getSpanner().getOptions().getDirectedReadOptions()) .setSpan(sessionClient.getSpanner().getTracer().getCurrentSpan()) diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DecodeMode.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DecodeMode.java new file mode 100644 index 00000000000..c1bea9a3ce1 --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DecodeMode.java @@ -0,0 +1,35 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +/** Specifies how and when to decode a value from protobuf to a plain Java object. */ +public enum DecodeMode { + /** + * Decodes all columns of a row directly when a {@link ResultSet} is advanced to the next row with + * {@link ResultSet#next()} + */ + DIRECT, + /** + * Decodes all columns of a row the first time a {@link ResultSet} value is retrieved from the + * row. + */ + LAZY_PER_ROW, + /** + * Decodes a columns of a row the first time the value of that column is retrieved from the row. + */ + LAZY_PER_COL, +} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java index c29282879ed..18ecbeceb0f 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java @@ -23,7 +23,7 @@ import com.google.spanner.v1.ResultSetStats; /** Forwarding implementation of ResultSet that forwards all calls to a delegate. */ -public class ForwardingResultSet extends ForwardingStructReader implements ResultSet { +public class ForwardingResultSet extends ForwardingStructReader implements ProtobufResultSet { private Supplier delegate; @@ -55,6 +55,22 @@ public boolean next() throws SpannerException { return delegate.get().next(); } + @Override + public boolean canGetProtobufValue(int columnIndex) { + ResultSet resultSetDelegate = delegate.get(); + return (resultSetDelegate instanceof ProtobufResultSet) + && ((ProtobufResultSet) resultSetDelegate).canGetProtobufValue(columnIndex); + } + + @Override + public com.google.protobuf.Value getProtobufValue(int columnIndex) { + ResultSet resultSetDelegate = delegate.get(); + Preconditions.checkState( + resultSetDelegate instanceof ProtobufResultSet, + "The result set does not support protobuf values"); + return ((ProtobufResultSet) resultSetDelegate).getProtobufValue(columnIndex); + } + @Override public Struct getCurrentRowAsStruct() { return delegate.get().getCurrentRowAsStruct(); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcResultSet.java index b5d64bce3bb..37a4792ad87 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcResultSet.java @@ -20,6 +20,7 @@ import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; +import com.google.protobuf.Value; import com.google.spanner.v1.PartialResultSet; import com.google.spanner.v1.ResultSetMetadata; import com.google.spanner.v1.ResultSetStats; @@ -28,9 +29,10 @@ import javax.annotation.Nullable; @VisibleForTesting -class GrpcResultSet extends AbstractResultSet> { +class GrpcResultSet extends AbstractResultSet> implements ProtobufResultSet { private final GrpcValueIterator iterator; private final Listener listener; + private final DecodeMode decodeMode; private ResultSetMetadata metadata; private GrpcStruct currRow; private SpannerException error; @@ -38,8 +40,26 @@ class GrpcResultSet extends AbstractResultSet> { private boolean closed; GrpcResultSet(CloseableIterator iterator, Listener listener) { + this(iterator, listener, DecodeMode.DIRECT); + } + + GrpcResultSet( + CloseableIterator iterator, Listener listener, DecodeMode decodeMode) { this.iterator = new GrpcValueIterator(iterator); this.listener = listener; + this.decodeMode = decodeMode; + } + + @Override + public boolean canGetProtobufValue(int columnIndex) { + return !closed && currRow != null && currRow.canGetProtoValue(columnIndex); + } + + @Override + public Value getProtobufValue(int columnIndex) { + checkState(!closed, "ResultSet is closed"); + checkState(currRow != null, "next() call required"); + return currRow.getProtoValueInternal(columnIndex); } @Override @@ -65,7 +85,7 @@ public boolean next() throws SpannerException { throw SpannerExceptionFactory.newSpannerException( ErrorCode.FAILED_PRECONDITION, AbstractReadContext.NO_TRANSACTION_RETURNED_MSG); } - currRow = new GrpcStruct(iterator.type(), new ArrayList<>()); + currRow = new GrpcStruct(iterator.type(), new ArrayList<>(), decodeMode); } boolean hasNext = currRow.consumeRow(iterator); if (!hasNext) { diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStruct.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStruct.java index 0d8a5545a90..152c82e9ca9 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStruct.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStruct.java @@ -27,6 +27,7 @@ import com.google.cloud.spanner.AbstractResultSet.Float64Array; import com.google.cloud.spanner.AbstractResultSet.Int64Array; import com.google.cloud.spanner.AbstractResultSet.LazyByteArray; +import com.google.cloud.spanner.Type.Code; import com.google.cloud.spanner.Type.StructField; import com.google.common.base.Preconditions; import com.google.common.collect.Lists; @@ -42,6 +43,7 @@ import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Base64; +import java.util.BitSet; import java.util.Collections; import java.util.Iterator; import java.util.List; @@ -54,6 +56,9 @@ class GrpcStruct extends Struct implements Serializable { private final Type type; private final List rowData; + private final DecodeMode decodeMode; + private final BitSet colDecoded; + private boolean rowDecoded; /** * Builds an immutable version of this struct using {@link Struct#newBuilder()} which is used as a @@ -181,9 +186,28 @@ private Object writeReplace() { return builder.build(); } - GrpcStruct(Type type, List rowData) { + GrpcStruct(Type type, List rowData, DecodeMode decodeMode) { + this( + type, + rowData, + decodeMode, + /* rowDecoded = */ false, + /* colDecoded = */ decodeMode == DecodeMode.LAZY_PER_COL + ? new BitSet(type.getStructFields().size()) + : null); + } + + private GrpcStruct( + Type type, + List rowData, + DecodeMode decodeMode, + boolean rowDecoded, + BitSet colDecoded) { this.type = type; this.rowData = rowData; + this.decodeMode = decodeMode; + this.rowDecoded = rowDecoded; + this.colDecoded = colDecoded; } @Override @@ -193,6 +217,11 @@ public String toString() { boolean consumeRow(Iterator iterator) { rowData.clear(); + if (decodeMode == DecodeMode.LAZY_PER_ROW) { + rowDecoded = false; + } else if (decodeMode == DecodeMode.LAZY_PER_COL) { + colDecoded.clear(); + } if (!iterator.hasNext()) { return false; } @@ -203,7 +232,11 @@ boolean consumeRow(Iterator iterator) { "Invalid value stream: end of stream reached before row is complete"); } com.google.protobuf.Value value = iterator.next(); - rowData.add(decodeValue(fieldType.getType(), value)); + if (decodeMode == DecodeMode.DIRECT) { + rowData.add(decodeValue(fieldType.getType(), value)); + } else { + rowData.add(value); + } } return true; } @@ -266,7 +299,7 @@ private static Struct decodeStructValue(Type structType, ListValue structValue) for (int i = 0; i < fieldTypes.size(); ++i) { fields.add(decodeValue(fieldTypes.get(i).getType(), fieldValues.get(i))); } - return new GrpcStruct(structType, fields); + return new GrpcStruct(structType, fields, DecodeMode.DIRECT); } static Object decodeArrayValue(Type elementType, ListValue listValue) { @@ -310,7 +343,12 @@ private static void checkType( } Struct immutableCopy() { - return new GrpcStruct(type, new ArrayList<>(rowData)); + return new GrpcStruct( + type, + new ArrayList<>(rowData), + this.decodeMode, + this.rowDecoded, + this.colDecoded == null ? null : (BitSet) this.colDecoded.clone()); } @Override @@ -320,6 +358,10 @@ public Type getType() { @Override public boolean isNull(int columnIndex) { + if ((decodeMode == DecodeMode.LAZY_PER_ROW && !rowDecoded) + || (decodeMode == DecodeMode.LAZY_PER_COL && !colDecoded.get(columnIndex))) { + return ((com.google.protobuf.Value) rowData.get(columnIndex)).hasNullValue(); + } return rowData.get(columnIndex) == null; } @@ -355,64 +397,123 @@ protected T getProtoEnumInternal( @Override protected boolean getBooleanInternal(int columnIndex) { + ensureDecoded(columnIndex); return (Boolean) rowData.get(columnIndex); } @Override protected long getLongInternal(int columnIndex) { + ensureDecoded(columnIndex); return (Long) rowData.get(columnIndex); } @Override protected double getDoubleInternal(int columnIndex) { + ensureDecoded(columnIndex); return (Double) rowData.get(columnIndex); } @Override protected BigDecimal getBigDecimalInternal(int columnIndex) { + ensureDecoded(columnIndex); return (BigDecimal) rowData.get(columnIndex); } @Override protected String getStringInternal(int columnIndex) { + ensureDecoded(columnIndex); return (String) rowData.get(columnIndex); } @Override protected String getJsonInternal(int columnIndex) { + ensureDecoded(columnIndex); return (String) rowData.get(columnIndex); } @Override protected String getPgJsonbInternal(int columnIndex) { + ensureDecoded(columnIndex); return (String) rowData.get(columnIndex); } @Override protected ByteArray getBytesInternal(int columnIndex) { + ensureDecoded(columnIndex); return getLazyBytesInternal(columnIndex).getByteArray(); } LazyByteArray getLazyBytesInternal(int columnIndex) { + ensureDecoded(columnIndex); return (LazyByteArray) rowData.get(columnIndex); } @Override protected Timestamp getTimestampInternal(int columnIndex) { + ensureDecoded(columnIndex); return (Timestamp) rowData.get(columnIndex); } @Override protected Date getDateInternal(int columnIndex) { + ensureDecoded(columnIndex); return (Date) rowData.get(columnIndex); } + private boolean isUnrecognizedType(int columnIndex) { + return type.getStructFields().get(columnIndex).getType().getCode() == Code.UNRECOGNIZED; + } + + boolean canGetProtoValue(int columnIndex) { + return isUnrecognizedType(columnIndex) + || (decodeMode == DecodeMode.LAZY_PER_ROW && !rowDecoded) + || (decodeMode == DecodeMode.LAZY_PER_COL && !colDecoded.get(columnIndex)); + } + protected com.google.protobuf.Value getProtoValueInternal(int columnIndex) { + checkProtoValueSupported(columnIndex); return (com.google.protobuf.Value) rowData.get(columnIndex); } + private void checkProtoValueSupported(int columnIndex) { + // Unrecognized types are returned as protobuf values. + if (isUnrecognizedType(columnIndex)) { + return; + } + Preconditions.checkState( + decodeMode != DecodeMode.DIRECT, + "Getting proto value is not supported when DecodeMode#DIRECT is used."); + Preconditions.checkState( + !(decodeMode == DecodeMode.LAZY_PER_ROW && rowDecoded), + "Getting proto value after the row has been decoded is not supported."); + Preconditions.checkState( + !(decodeMode == DecodeMode.LAZY_PER_COL && colDecoded.get(columnIndex)), + "Getting proto value after the column has been decoded is not supported."); + } + + private void ensureDecoded(int columnIndex) { + if (decodeMode == DecodeMode.LAZY_PER_ROW && !rowDecoded) { + for (int i = 0; i < rowData.size(); i++) { + rowData.set( + i, + decodeValue( + type.getStructFields().get(i).getType(), + (com.google.protobuf.Value) rowData.get(i))); + } + rowDecoded = true; + } else if (decodeMode == DecodeMode.LAZY_PER_COL && !colDecoded.get(columnIndex)) { + rowData.set( + columnIndex, + decodeValue( + type.getStructFields().get(columnIndex).getType(), + (com.google.protobuf.Value) rowData.get(columnIndex))); + colDecoded.set(columnIndex); + } + } + @Override protected Value getValueInternal(int columnIndex) { + ensureDecoded(columnIndex); final List structFields = getType().getStructFields(); final StructField structField = structFields.get(columnIndex); final Type columnType = structField.getType(); @@ -423,7 +524,8 @@ protected Value getValueInternal(int columnIndex) { case INT64: return Value.int64(isNull ? null : getLongInternal(columnIndex)); case ENUM: - return Value.protoEnum(getLongInternal(columnIndex), columnType.getProtoTypeFqn()); + return Value.protoEnum( + isNull ? null : getLongInternal(columnIndex), columnType.getProtoTypeFqn()); case NUMERIC: return Value.numeric(isNull ? null : getBigDecimalInternal(columnIndex)); case PG_NUMERIC: @@ -439,7 +541,8 @@ protected Value getValueInternal(int columnIndex) { case BYTES: return Value.internalBytes(isNull ? null : getLazyBytesInternal(columnIndex)); case PROTO: - return Value.protoMessage(getBytesInternal(columnIndex), columnType.getProtoTypeFqn()); + return Value.protoMessage( + isNull ? null : getBytesInternal(columnIndex), columnType.getProtoTypeFqn()); case TIMESTAMP: return Value.timestamp(isNull ? null : getTimestampInternal(columnIndex)); case DATE: @@ -494,11 +597,13 @@ protected Value getValueInternal(int columnIndex) { @Override protected Struct getStructInternal(int columnIndex) { + ensureDecoded(columnIndex); return (Struct) rowData.get(columnIndex); } @Override protected boolean[] getBooleanArrayInternal(int columnIndex) { + ensureDecoded(columnIndex); @SuppressWarnings("unchecked") // We know ARRAY produces a List. List values = (List) rowData.get(columnIndex); boolean[] r = new boolean[values.size()]; @@ -514,44 +619,52 @@ protected boolean[] getBooleanArrayInternal(int columnIndex) { @Override @SuppressWarnings("unchecked") // We know ARRAY produces a List. protected List getBooleanListInternal(int columnIndex) { + ensureDecoded(columnIndex); return Collections.unmodifiableList((List) rowData.get(columnIndex)); } @Override protected long[] getLongArrayInternal(int columnIndex) { + ensureDecoded(columnIndex); return getLongListInternal(columnIndex).toPrimitiveArray(columnIndex); } @Override protected Int64Array getLongListInternal(int columnIndex) { + ensureDecoded(columnIndex); return (Int64Array) rowData.get(columnIndex); } @Override protected double[] getDoubleArrayInternal(int columnIndex) { + ensureDecoded(columnIndex); return getDoubleListInternal(columnIndex).toPrimitiveArray(columnIndex); } @Override protected Float64Array getDoubleListInternal(int columnIndex) { + ensureDecoded(columnIndex); return (Float64Array) rowData.get(columnIndex); } @Override @SuppressWarnings("unchecked") // We know ARRAY produces a List. protected List getBigDecimalListInternal(int columnIndex) { + ensureDecoded(columnIndex); return (List) rowData.get(columnIndex); } @Override @SuppressWarnings("unchecked") // We know ARRAY produces a List. protected List getStringListInternal(int columnIndex) { + ensureDecoded(columnIndex); return Collections.unmodifiableList((List) rowData.get(columnIndex)); } @Override @SuppressWarnings("unchecked") // We know ARRAY produces a List. protected List getJsonListInternal(int columnIndex) { + ensureDecoded(columnIndex); return Collections.unmodifiableList((List) rowData.get(columnIndex)); } @@ -562,6 +675,7 @@ protected List getProtoMessageListInternal( Preconditions.checkNotNull( message, "Proto message may not be null. Use MyProtoClass.getDefaultInstance() as a parameter value."); + ensureDecoded(columnIndex); List bytesArray = (List) rowData.get(columnIndex); @@ -596,6 +710,7 @@ protected List getProtoEnumListInternal( int columnIndex, Function method) { Preconditions.checkNotNull( method, "Method may not be null. Use 'MyProtoEnum::forNumber' as a parameter value."); + ensureDecoded(columnIndex); List enumIntArray = (List) rowData.get(columnIndex); List protoEnumList = new ArrayList<>(enumIntArray.size()); @@ -613,12 +728,14 @@ protected List getProtoEnumListInternal( @Override @SuppressWarnings("unchecked") // We know ARRAY produces a List. protected List getPgJsonbListInternal(int columnIndex) { + ensureDecoded(columnIndex); return Collections.unmodifiableList((List) rowData.get(columnIndex)); } @Override @SuppressWarnings("unchecked") // We know ARRAY produces a List. protected List getBytesListInternal(int columnIndex) { + ensureDecoded(columnIndex); return Lists.transform( (List) rowData.get(columnIndex), l -> l == null ? null : l.getByteArray()); } @@ -626,18 +743,21 @@ protected List getBytesListInternal(int columnIndex) { @Override @SuppressWarnings("unchecked") // We know ARRAY produces a List. protected List getTimestampListInternal(int columnIndex) { + ensureDecoded(columnIndex); return Collections.unmodifiableList((List) rowData.get(columnIndex)); } @Override @SuppressWarnings("unchecked") // We know ARRAY produces a List. protected List getDateListInternal(int columnIndex) { + ensureDecoded(columnIndex); return Collections.unmodifiableList((List) rowData.get(columnIndex)); } @Override @SuppressWarnings("unchecked") // We know ARRAY> produces a List. protected List getStructListInternal(int columnIndex) { + ensureDecoded(columnIndex); return Collections.unmodifiableList((List) rowData.get(columnIndex)); } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java index 57feabbfcca..76d0f24225a 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java @@ -243,6 +243,10 @@ public static ReadAndQueryOption directedRead(DirectedReadOptions directedReadOp return new DirectedReadOption(directedReadOptions); } + public static ReadAndQueryOption decodeMode(DecodeMode decodeMode) { + return new DecodeOption(decodeMode); + } + /** Option to request {@link CommitStats} for read/write transactions. */ static final class CommitStatsOption extends InternalOption implements TransactionOption { @Override @@ -374,6 +378,19 @@ void appendToOptions(Options options) { } } + static final class DecodeOption extends InternalOption implements ReadAndQueryOption { + private final DecodeMode decodeMode; + + DecodeOption(DecodeMode decodeMode) { + this.decodeMode = Preconditions.checkNotNull(decodeMode, "DecodeMode cannot be null"); + } + + @Override + void appendToOptions(Options options) { + options.decodeMode = decodeMode; + } + } + private boolean withCommitStats; private Duration maxCommitDelay; @@ -391,6 +408,7 @@ void appendToOptions(Options options) { private Boolean withOptimisticLock; private Boolean dataBoostEnabled; private DirectedReadOptions directedReadOptions; + private DecodeMode decodeMode; // Construction is via factory methods below. private Options() {} @@ -507,6 +525,14 @@ DirectedReadOptions directedReadOptions() { return directedReadOptions; } + boolean hasDecodeMode() { + return decodeMode != null; + } + + DecodeMode decodeMode() { + return decodeMode; + } + @Override public String toString() { StringBuilder b = new StringBuilder(); @@ -552,6 +578,9 @@ public String toString() { if (directedReadOptions != null) { b.append("directedReadOptions: ").append(directedReadOptions).append(' '); } + if (decodeMode != null) { + b.append("decodeMode: ").append(decodeMode).append(' '); + } return b.toString(); } @@ -640,6 +669,9 @@ public int hashCode() { if (directedReadOptions != null) { result = 31 * result + directedReadOptions.hashCode(); } + if (decodeMode != null) { + result = 31 * result + decodeMode.hashCode(); + } return result; } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ProtobufResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ProtobufResultSet.java new file mode 100644 index 00000000000..bbd8c41291f --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ProtobufResultSet.java @@ -0,0 +1,42 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import com.google.api.core.InternalApi; +import com.google.protobuf.Value; + +/** Interface for {@link ResultSet}s that can return a protobuf value. */ +@InternalApi +public interface ProtobufResultSet extends ResultSet { + + /** Returns true if the protobuf value for the given column is still available. */ + boolean canGetProtobufValue(int columnIndex); + + /** + * Returns the column value as a protobuf value. + * + *

This is an internal method not intended for external usage. + * + *

This method may only be called before the column value has been decoded to a plain Java + * object. This means that the {@link DecodeMode} that is used for the {@link ResultSet} must be + * one of {@link DecodeMode#LAZY_PER_ROW} and {@link DecodeMode#LAZY_PER_COL}, and that the + * corresponding {@link ResultSet#getValue(int)}, {@link ResultSet#getBoolean(int)}, ... method + * may not yet have been called for the column. + */ + @InternalApi + Value getProtobufValue(int columnIndex); +} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResultSets.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResultSets.java index d55d4091b9f..a6cc7c729e5 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResultSets.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResultSets.java @@ -109,7 +109,7 @@ public ResultSet get() { } } - private static class PrePopulatedResultSet implements ResultSet { + private static class PrePopulatedResultSet implements ProtobufResultSet { private final List rows; private final Type type; private int index = -1; @@ -137,6 +137,19 @@ public boolean next() throws SpannerException { return ++index < rows.size(); } + @Override + public boolean canGetProtobufValue(int columnIndex) { + return !closed && index >= 0 && index < rows.size(); + } + + @Override + public com.google.protobuf.Value getProtobufValue(int columnIndex) { + Preconditions.checkState(!closed, "ResultSet is closed"); + Preconditions.checkState(index >= 0, "Must be preceded by a next() call"); + Preconditions.checkElementIndex(index, rows.size(), "All rows have been yielded"); + return getValue(columnIndex).toProto(); + } + @Override public Struct getCurrentRowAsStruct() { Preconditions.checkState(!closed, "ResultSet is closed"); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java index 29928f61cec..81b00001105 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java @@ -263,6 +263,7 @@ public ReadContext singleUse(TimestampBound bound) { .setRpc(spanner.getRpc()) .setDefaultQueryOptions(spanner.getDefaultQueryOptions(databaseId)) .setDefaultPrefetchChunks(spanner.getDefaultPrefetchChunks()) + .setDefaultDecodeMode(spanner.getDefaultDecodeMode()) .setDefaultDirectedReadOptions(spanner.getOptions().getDirectedReadOptions()) .setSpan(currentSpan) .setTracer(tracer) @@ -284,6 +285,7 @@ public ReadOnlyTransaction singleUseReadOnlyTransaction(TimestampBound bound) { .setRpc(spanner.getRpc()) .setDefaultQueryOptions(spanner.getDefaultQueryOptions(databaseId)) .setDefaultPrefetchChunks(spanner.getDefaultPrefetchChunks()) + .setDefaultDecodeMode(spanner.getDefaultDecodeMode()) .setDefaultDirectedReadOptions(spanner.getOptions().getDirectedReadOptions()) .setSpan(currentSpan) .setTracer(tracer) @@ -305,6 +307,7 @@ public ReadOnlyTransaction readOnlyTransaction(TimestampBound bound) { .setRpc(spanner.getRpc()) .setDefaultQueryOptions(spanner.getDefaultQueryOptions(databaseId)) .setDefaultPrefetchChunks(spanner.getDefaultPrefetchChunks()) + .setDefaultDecodeMode(spanner.getDefaultDecodeMode()) .setDefaultDirectedReadOptions(spanner.getOptions().getDirectedReadOptions()) .setSpan(currentSpan) .setTracer(tracer) @@ -423,6 +426,7 @@ TransactionContextImpl newTransaction(Options options) { .setRpc(spanner.getRpc()) .setDefaultQueryOptions(spanner.getDefaultQueryOptions(databaseId)) .setDefaultPrefetchChunks(spanner.getDefaultPrefetchChunks()) + .setDefaultDecodeMode(spanner.getDefaultDecodeMode()) .setSpan(currentSpan) .setTracer(tracer) .setExecutorProvider(spanner.getAsyncExecutorProvider()) diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java index 326a51d803e..8fe06f76cc8 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java @@ -151,6 +151,10 @@ int getDefaultPrefetchChunks() { return getOptions().getPrefetchChunks(); } + DecodeMode getDefaultDecodeMode() { + return getOptions().getDecodeMode(); + } + /** Returns the default query options that should be used for the specified database. */ QueryOptions getDefaultQueryOptions(DatabaseId databaseId) { return getOptions().getDefaultQueryOptions(databaseId); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java index 9c6044aa938..a16be179ce3 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java @@ -113,6 +113,7 @@ public class SpannerOptions extends ServiceOptions { private final GrpcInterceptorProvider interceptorProvider; private final SessionPoolOptions sessionPoolOptions; private final int prefetchChunks; + private final DecodeMode decodeMode; private final int numChannels; private final String transportChannelExecutorThreadNameFormat; private final String databaseRole; @@ -616,6 +617,7 @@ protected SpannerOptions(Builder builder) { ? builder.sessionPoolOptions : SessionPoolOptions.newBuilder().build(); prefetchChunks = builder.prefetchChunks; + decodeMode = builder.decodeMode; databaseRole = builder.databaseRole; sessionLabels = builder.sessionLabels; try { @@ -704,6 +706,9 @@ public static class Builder extends ServiceOptions.Builder { static final int DEFAULT_PREFETCH_CHUNKS = 4; static final QueryOptions DEFAULT_QUERY_OPTIONS = QueryOptions.getDefaultInstance(); + // TODO: Set the default to DecodeMode.DIRECT before merging to keep the current default. + // It is currently set to LAZY_PER_COL so it is used in all tests. + static final DecodeMode DEFAULT_DECODE_MODE = DecodeMode.LAZY_PER_COL; static final RetrySettings DEFAULT_ADMIN_REQUESTS_LIMIT_EXCEEDED_RETRY_SETTINGS = RetrySettings.newBuilder() .setInitialRetryDelay(Duration.ofSeconds(5L)) @@ -730,6 +735,7 @@ public static class Builder private String transportChannelExecutorThreadNameFormat = "Cloud-Spanner-TransportChannel-%d"; private int prefetchChunks = DEFAULT_PREFETCH_CHUNKS; + private DecodeMode decodeMode = DEFAULT_DECODE_MODE; private SessionPoolOptions sessionPoolOptions; private String databaseRole; private ImmutableMap sessionLabels; @@ -797,6 +803,7 @@ protected Builder() { options.transportChannelExecutorThreadNameFormat; this.sessionPoolOptions = options.sessionPoolOptions; this.prefetchChunks = options.prefetchChunks; + this.decodeMode = options.decodeMode; this.databaseRole = options.databaseRole; this.sessionLabels = options.sessionLabels; this.spannerStubSettingsBuilder = options.spannerStubSettings.toBuilder(); @@ -1224,6 +1231,15 @@ public Builder setPrefetchChunks(int prefetchChunks) { return this; } + /** + * Specifies how values that are returned from a query should be decoded and converted from + * protobuf values into plain Java objects. + */ + public Builder setDecodeMode(DecodeMode decodeMode) { + this.decodeMode = decodeMode; + return this; + } + @Override public Builder setHost(String host) { super.setHost(host); @@ -1568,6 +1584,10 @@ public int getPrefetchChunks() { return prefetchChunks; } + public DecodeMode getDecodeMode() { + return decodeMode; + } + public static GrpcTransportOptions getDefaultGrpcTransportOptions() { return GrpcTransportOptions.newBuilder().build(); } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Value.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Value.java index a845eb118bf..3f0155e4a5e 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Value.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Value.java @@ -1518,6 +1518,10 @@ void valueToString(StringBuilder b) { @Override boolean valueEquals(Value v) { + // NaN == NaN always returns false, so we need a custom check. + if (Double.isNaN(this.value)) { + return Double.isNaN(((Float64Impl) v).value); + } return ((Float64Impl) v).value == value; } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ChecksumResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ChecksumResultSet.java index dc373cf03bd..c642d7e505a 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ChecksumResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ChecksumResultSet.java @@ -16,28 +16,28 @@ package com.google.cloud.spanner.connection; -import com.google.cloud.ByteArray; -import com.google.cloud.Date; -import com.google.cloud.Timestamp; import com.google.cloud.spanner.AbortedException; +import com.google.cloud.spanner.ErrorCode; import com.google.cloud.spanner.Options.QueryOption; +import com.google.cloud.spanner.ProtobufResultSet; import com.google.cloud.spanner.ResultSet; import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.SpannerExceptionFactory; -import com.google.cloud.spanner.Struct; +import com.google.cloud.spanner.Type; import com.google.cloud.spanner.Type.Code; +import com.google.cloud.spanner.Type.StructField; import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement; import com.google.cloud.spanner.connection.ReadWriteTransaction.RetriableStatement; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; -import com.google.common.hash.Funnel; import com.google.common.hash.HashCode; -import com.google.common.hash.HashFunction; -import com.google.common.hash.Hasher; -import com.google.common.hash.Hashing; -import com.google.common.hash.PrimitiveSink; -import java.math.BigDecimal; -import java.util.Objects; +import com.google.protobuf.Value; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.CharsetEncoder; +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.util.Arrays; import java.util.concurrent.Callable; import java.util.concurrent.atomic.AtomicLong; @@ -71,11 +71,11 @@ class ChecksumResultSet extends ReplaceableForwardingResultSet implements Retria private final ParsedStatement statement; private final AnalyzeMode analyzeMode; private final QueryOption[] options; - private final ChecksumResultSet.ChecksumCalculator checksumCalculator = new ChecksumCalculator(); + private final ChecksumCalculator checksumCalculator = new ChecksumCalculator(); ChecksumResultSet( ReadWriteTransaction transaction, - ResultSet delegate, + ProtobufResultSet delegate, ParsedStatement statement, AnalyzeMode analyzeMode, QueryOption... options) { @@ -91,6 +91,13 @@ class ChecksumResultSet extends ReplaceableForwardingResultSet implements Retria this.options = options; } + @Override + public Value getProtobufValue(int columnIndex) { + // We can safely cast to ProtobufResultSet here, as the constructor of this class only accepts + // instances of ProtobufResultSet. + return ((ProtobufResultSet) getDelegate()).getProtobufValue(columnIndex); + } + /** Simple {@link Callable} for calling {@link ResultSet#next()} */ private final class NextCallable implements Callable { @Override @@ -102,7 +109,7 @@ public Boolean call() { boolean res = ChecksumResultSet.super.next(); // Only update the checksum if there was another row to be consumed. if (res) { - checksumCalculator.calculateNextChecksum(getCurrentRowAsStruct()); + checksumCalculator.calculateNextChecksum(ChecksumResultSet.this); } numberOfNextCalls.incrementAndGet(); return res; @@ -118,8 +125,9 @@ public boolean next() { } @VisibleForTesting - HashCode getChecksum() { - // HashCode is immutable and can be safely returned. + byte[] getChecksum() { + // Getting the checksum from the checksumCalculator will create a clone of the current digest + // and return the checksum from the clone, so it is safe to return this value. return checksumCalculator.getChecksum(); } @@ -132,8 +140,8 @@ HashCode getChecksum() { @Override public void retry(AbortedException aborted) throws AbortedException { // Execute the same query and consume the result set to the same point as the original. - ChecksumResultSet.ChecksumCalculator newChecksumCalculator = new ChecksumCalculator(); - ResultSet resultSet = null; + ChecksumCalculator newChecksumCalculator = new ChecksumCalculator(); + ProtobufResultSet resultSet = null; long counter = 0L; try { transaction @@ -150,7 +158,7 @@ public void retry(AbortedException aborted) throws AbortedException { statement, StatementExecutionStep.RETRY_NEXT_ON_RESULT_SET, transaction); next = resultSet.next(); if (next) { - newChecksumCalculator.calculateNextChecksum(resultSet.getCurrentRowAsStruct()); + newChecksumCalculator.calculateNextChecksum(resultSet); } counter++; } @@ -168,9 +176,9 @@ public void retry(AbortedException aborted) throws AbortedException { throw e; } // Check that we have the same number of rows and the same checksum. - HashCode newChecksum = newChecksumCalculator.getChecksum(); - HashCode currentChecksum = checksumCalculator.getChecksum(); - if (counter == numberOfNextCalls.get() && Objects.equals(newChecksum, currentChecksum)) { + byte[] newChecksum = newChecksumCalculator.getChecksum(); + byte[] currentChecksum = checksumCalculator.getChecksum(); + if (counter == numberOfNextCalls.get() && Arrays.equals(newChecksum, currentChecksum)) { // Checksum is ok, we only need to replace the delegate result set if it's still open. if (isClosed()) { resultSet.close(); @@ -184,222 +192,165 @@ public void retry(AbortedException aborted) throws AbortedException { } } - /** Calculates and keeps the current checksum of a {@link ChecksumResultSet} */ + /** + * Calculates a running checksum for all the data that has been seen sofar in this result set. The + * calculation is performed on the protobuf values that were returned by Cloud Spanner, which + * means that no decoding of the results is needed (or allowed!) before calculating the checksum. + * This is more efficient, both in terms of CPU usage and memory consumption, especially if the + * consumer of the result set does not read all values, or is only reading the underlying protobuf + * values. + */ private static final class ChecksumCalculator { - private static final HashFunction SHA256_FUNCTION = Hashing.sha256(); - private HashCode currentChecksum; + // Use a buffer of max 1Mb to hash string data. This means that strings of up to 1Mb in size + // will be hashed in one go, while strings larger than 1Mb will be chunked into pieces of at + // most 1Mb and then fed into the digest. The digest internally creates a copy of the string + // that is being hashed, so chunking large strings prevents them from being loaded into memory + // twice. + private static final int MAX_BUFFER_SIZE = 1 << 20; - private void calculateNextChecksum(Struct row) { - Hasher hasher = SHA256_FUNCTION.newHasher(); - if (currentChecksum != null) { - hasher.putBytes(currentChecksum.asBytes()); + private boolean firstRow = true; + private final MessageDigest digest; + private ByteBuffer buffer; + private ByteBuffer float64Buffer; + + ChecksumCalculator() { + try { + // This is safe, as all Java implementations are required to have MD5 implemented. + // See https://docs.oracle.com/javase/8/docs/api/java/security/MessageDigest.html + // MD5 requires less CPU power than SHA-256, and still offers a low enough collision + // probability for the use case at hand here. + digest = MessageDigest.getInstance("MD5"); + } catch (Throwable t) { + throw SpannerExceptionFactory.asSpannerException(t); } - hasher.putObject(row, StructFunnel.INSTANCE); - currentChecksum = hasher.hash(); } - private HashCode getChecksum() { - return currentChecksum; + private byte[] getChecksum() { + try { + // This is safe, as the MD5 MessageDigest is known to be cloneable. + MessageDigest clone = (MessageDigest) digest.clone(); + return clone.digest(); + } catch (CloneNotSupportedException e) { + throw SpannerExceptionFactory.asSpannerException(e); + } } - } - /** - * A {@link Funnel} implementation for calculating a {@link HashCode} for each row in a {@link - * ResultSet}. - */ - private enum StructFunnel implements Funnel { - INSTANCE; - private static final String NULL = "null"; - - @Override - public void funnel(Struct row, PrimitiveSink into) { - for (int i = 0; i < row.getColumnCount(); i++) { - if (row.isNull(i)) { - funnelValue(Code.STRING, null, into); + private void calculateNextChecksum(ProtobufResultSet resultSet) { + if (firstRow) { + for (StructField field : resultSet.getType().getStructFields()) { + digest.update(field.getType().toString().getBytes(StandardCharsets.UTF_8)); + } + } + for (int col = 0; col < resultSet.getColumnCount(); col++) { + Type type = resultSet.getColumnType(col); + if (resultSet.canGetProtobufValue(col)) { + Value value = resultSet.getProtobufValue(col); + digest.update((byte) value.getKindCase().getNumber()); + pushValue(type, value); } else { - Code type = row.getColumnType(i).getCode(); - switch (type) { - case ARRAY: - funnelArray(row.getColumnType(i).getArrayElementType().getCode(), row, i, into); - break; - case BOOL: - funnelValue(type, row.getBoolean(i), into); - break; - case BYTES: - case PROTO: - funnelValue(type, row.getBytes(i), into); - break; - case DATE: - funnelValue(type, row.getDate(i), into); - break; - case FLOAT64: - funnelValue(type, row.getDouble(i), into); - break; - case NUMERIC: - funnelValue(type, row.getBigDecimal(i), into); - break; - case PG_NUMERIC: - funnelValue(type, row.getString(i), into); - break; - case INT64: - case ENUM: - funnelValue(type, row.getLong(i), into); - break; - case STRING: - funnelValue(type, row.getString(i), into); - break; - case JSON: - funnelValue(type, row.getJson(i), into); - break; - case PG_JSONB: - funnelValue(type, row.getPgJsonb(i), into); - break; - case TIMESTAMP: - funnelValue(type, row.getTimestamp(i), into); - break; - - case STRUCT: - default: - throw new IllegalArgumentException("unsupported row type"); - } + // This will normally not happen, unless the user explicitly sets the decoding mode to + // DIRECT for a query in a read/write transaction. The default decoding mode in the + // Connection API is set to LAZY_PER_COL. + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.FAILED_PRECONDITION, + "Failed to get the underlying protobuf value for the column " + + resultSet.getMetadata().getRowType().getFields(col).getName() + + ". " + + "Executing queries with DecodeMode#DIRECT is not supported in read/write transactions."); } } + firstRow = false; } - private void funnelArray( - Code arrayElementType, Struct row, int columnIndex, PrimitiveSink into) { - funnelValue(Code.STRING, "BeginArray", into); - switch (arrayElementType) { - case BOOL: - into.putInt(row.getBooleanList(columnIndex).size()); - for (Boolean value : row.getBooleanList(columnIndex)) { - funnelValue(Code.BOOL, value, into); - } + private void pushValue(Type type, Value value) { + // Protobuf Value has a very limited set of possible types of values. All Cloud Spanner types + // are mapped to one of the protobuf values listed here, meaning that no changes are needed to + // this calculation when a new type is added to Cloud Spanner. + switch (value.getKindCase()) { + case NULL_VALUE: + // nothing needed, writing the KindCase is enough. break; - case BYTES: - case PROTO: - into.putInt(row.getBytesList(columnIndex).size()); - for (ByteArray value : row.getBytesList(columnIndex)) { - funnelValue(Code.BYTES, value, into); - } + case BOOL_VALUE: + digest.update(value.getBoolValue() ? (byte) 1 : 0); break; - case DATE: - into.putInt(row.getDateList(columnIndex).size()); - for (Date value : row.getDateList(columnIndex)) { - funnelValue(Code.DATE, value, into); - } + case STRING_VALUE: + putString(value.getStringValue()); break; - case FLOAT64: - into.putInt(row.getDoubleList(columnIndex).size()); - for (Double value : row.getDoubleList(columnIndex)) { - funnelValue(Code.FLOAT64, value, into); + case NUMBER_VALUE: + if (float64Buffer == null) { + // Create an 8-byte buffer that can be re-used for all float64 values in this result + // set. + float64Buffer = ByteBuffer.allocate(Double.BYTES); + } else { + float64Buffer.clear(); } + float64Buffer.putDouble(value.getNumberValue()); + float64Buffer.flip(); + digest.update(float64Buffer); break; - case NUMERIC: - into.putInt(row.getBigDecimalList(columnIndex).size()); - for (BigDecimal value : row.getBigDecimalList(columnIndex)) { - funnelValue(Code.NUMERIC, value, into); + case LIST_VALUE: + if (type.getCode() == Code.ARRAY) { + for (Value item : value.getListValue().getValuesList()) { + digest.update((byte) item.getKindCase().getNumber()); + pushValue(type.getArrayElementType(), item); + } + } else { + // This should not be possible. + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.FAILED_PRECONDITION, + "List values that are not an ARRAY are not supported"); } break; - case PG_NUMERIC: - into.putInt(row.getStringList(columnIndex).size()); - for (String value : row.getStringList(columnIndex)) { - funnelValue(Code.STRING, value, into); + case STRUCT_VALUE: + if (type.getCode() == Code.STRUCT) { + for (int col = 0; col < type.getStructFields().size(); col++) { + String name = type.getStructFields().get(col).getName(); + putString(name); + Value item = value.getStructValue().getFieldsMap().get(name); + digest.update((byte) item.getKindCase().getNumber()); + pushValue(type.getStructFields().get(col).getType(), item); + } + } else { + // This should not be possible. + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.FAILED_PRECONDITION, + "Struct values without a struct type are not supported"); } break; - case INT64: - case ENUM: - into.putInt(row.getLongList(columnIndex).size()); - for (Long value : row.getLongList(columnIndex)) { - funnelValue(Code.INT64, value, into); - } - break; - case STRING: - into.putInt(row.getStringList(columnIndex).size()); - for (String value : row.getStringList(columnIndex)) { - funnelValue(Code.STRING, value, into); - } - break; - case JSON: - into.putInt(row.getJsonList(columnIndex).size()); - for (String value : row.getJsonList(columnIndex)) { - funnelValue(Code.JSON, value, into); - } - break; - case PG_JSONB: - into.putInt(row.getPgJsonbList(columnIndex).size()); - for (String value : row.getPgJsonbList(columnIndex)) { - funnelValue(Code.PG_JSONB, value, into); - } - break; - case TIMESTAMP: - into.putInt(row.getTimestampList(columnIndex).size()); - for (Timestamp value : row.getTimestampList(columnIndex)) { - funnelValue(Code.TIMESTAMP, value, into); - } - break; - - case ARRAY: - case STRUCT: default: - throw new IllegalArgumentException("unsupported array element type"); + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.UNIMPLEMENTED, "Unsupported protobuf value: " + value.getKindCase()); } - funnelValue(Code.STRING, "EndArray", into); } - private void funnelValue(Code type, T value, PrimitiveSink into) { - // Include the type name in case the type of a column has changed. - into.putUnencodedChars(type.name()); - if (value == null) { - if (type == Code.BYTES || type == Code.STRING) { - // Put length -1 to distinguish from the string value 'null'. - into.putInt(-1); - } - into.putUnencodedChars(NULL); + /** Hashes a string value in blocks of max MAX_BUFFER_SIZE. */ + private void putString(String stringValue) { + int length = stringValue.length(); + if (buffer == null || (buffer.capacity() < MAX_BUFFER_SIZE && buffer.capacity() < length)) { + // Create a ByteBuffer with a maximum buffer size. + // This buffer is re-used for all string values in the result set. + buffer = ByteBuffer.allocate(Math.min(MAX_BUFFER_SIZE, length)); } else { - switch (type) { - case BOOL: - into.putBoolean((Boolean) value); - break; - case BYTES: - case PROTO: - ByteArray byteArray = (ByteArray) value; - into.putInt(byteArray.length()); - into.putBytes(byteArray.toByteArray()); - break; - case DATE: - Date date = (Date) value; - into.putInt(date.getYear()).putInt(date.getMonth()).putInt(date.getDayOfMonth()); - break; - case FLOAT64: - into.putDouble((Double) value); - break; - case NUMERIC: - String stringRepresentation = value.toString(); - into.putInt(stringRepresentation.length()); - into.putUnencodedChars(stringRepresentation); - break; - case INT64: - case ENUM: - into.putLong((Long) value); - break; - case PG_NUMERIC: - case STRING: - case JSON: - case PG_JSONB: - String stringValue = (String) value; - into.putInt(stringValue.length()); - into.putUnencodedChars(stringValue); - break; - case TIMESTAMP: - Timestamp timestamp = (Timestamp) value; - into.putLong(timestamp.getSeconds()).putInt(timestamp.getNanos()); - break; - case ARRAY: - case STRUCT: - default: - throw new IllegalArgumentException("invalid type for single value"); - } + buffer.clear(); + } + + // Wrap the string in a CharBuffer. This allows us to read from the string in sections without + // creating a new copy of (a part of) the string. E.g. using something like substring(..) + // would create a copy of that part of the string, using CharBuffer.wrap(..) does not. + CharBuffer source = CharBuffer.wrap(stringValue); + CharsetEncoder encoder = StandardCharsets.UTF_8.newEncoder(); + // source.hasRemaining() returns false when all the characters in the string have been + // processed. + while (source.hasRemaining()) { + // Encode the string into bytes and write them into the byte buffer. + // At most MAX_BUFFER_SIZE bytes will be written. + encoder.encode(source, buffer, false); + // Flip the buffer so we can read from the start. + buffer.flip(); + // Put the bytes from the buffer into the digest. + digest.update(buffer); + // Flip the buffer again, so we can repeat and write to the start of the buffer again. + buffer.flip(); } } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/DirectExecuteResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/DirectExecuteResultSet.java index dff915e2cce..1b15ec50822 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/DirectExecuteResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/DirectExecuteResultSet.java @@ -19,6 +19,7 @@ import com.google.cloud.ByteArray; import com.google.cloud.Date; import com.google.cloud.Timestamp; +import com.google.cloud.spanner.ProtobufResultSet; import com.google.cloud.spanner.ResultSet; import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.Struct; @@ -40,7 +41,7 @@ * to the actual query execution. It also ensures that any invalid query will throw an exception at * execution instead of the first next() call by a client. */ -class DirectExecuteResultSet implements ResultSet { +class DirectExecuteResultSet implements ProtobufResultSet { private static final String MISSING_NEXT_CALL = "Must be preceded by a next() call"; private final ResultSet delegate; private boolean nextCalledByClient = false; @@ -79,6 +80,21 @@ public boolean next() throws SpannerException { return initialNextResult; } + @Override + public boolean canGetProtobufValue(int columnIndex) { + return nextCalledByClient + && delegate instanceof ProtobufResultSet + && ((ProtobufResultSet) delegate).canGetProtobufValue(columnIndex); + } + + @Override + public com.google.protobuf.Value getProtobufValue(int columnIndex) { + Preconditions.checkState(nextCalledByClient, MISSING_NEXT_CALL); + Preconditions.checkState( + delegate instanceof ProtobufResultSet, "The result set does not support protobuf values"); + return ((ProtobufResultSet) delegate).getProtobufValue(columnIndex); + } + @Override public Struct getCurrentRowAsStruct() { Preconditions.checkState(nextCalledByClient, MISSING_NEXT_CALL); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java index e1fb87e4ade..6c4290c3b18 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java @@ -39,6 +39,7 @@ import com.google.cloud.spanner.Options.QueryOption; import com.google.cloud.spanner.Options.TransactionOption; import com.google.cloud.spanner.Options.UpdateOption; +import com.google.cloud.spanner.ProtobufResultSet; import com.google.cloud.spanner.ReadContext; import com.google.cloud.spanner.ResultSet; import com.google.cloud.spanner.SpannerException; @@ -427,7 +428,7 @@ public ApiFuture executeQueryAsync( statement, StatementExecutionStep.EXECUTE_STATEMENT, ReadWriteTransaction.this); - ResultSet delegate = + DirectExecuteResultSet delegate = DirectExecuteResultSet.ofResultSet( internalExecuteQuery(statement, analyzeMode, options)); return createAndAddRetryResultSet( @@ -797,7 +798,7 @@ T runWithRetry(Callable callable) throws SpannerException { * returns a retryable {@link ResultSet}. */ private ResultSet createAndAddRetryResultSet( - ResultSet resultSet, + ProtobufResultSet resultSet, ParsedStatement statement, AnalyzeMode analyzeMode, QueryOption... options) { @@ -1091,7 +1092,7 @@ interface RetriableStatement { /** Creates a {@link ChecksumResultSet} for this {@link ReadWriteTransaction}. */ @VisibleForTesting ChecksumResultSet createChecksumResultSet( - ResultSet delegate, + ProtobufResultSet delegate, ParsedStatement statement, AnalyzeMode analyzeMode, QueryOption... options) { diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSet.java index 7370551a46f..a8de14e5121 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSet.java @@ -20,6 +20,7 @@ import com.google.cloud.Date; import com.google.cloud.Timestamp; import com.google.cloud.spanner.ErrorCode; +import com.google.cloud.spanner.ProtobufResultSet; import com.google.cloud.spanner.ResultSet; import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.SpannerExceptionFactory; @@ -42,7 +43,7 @@ * that is fetched using the new transaction. This is achieved by wrapping the returned result sets * in a {@link ReplaceableForwardingResultSet} that replaces its delegate after a transaction retry. */ -class ReplaceableForwardingResultSet implements ResultSet { +class ReplaceableForwardingResultSet implements ProtobufResultSet { private ResultSet delegate; private boolean closed; @@ -60,6 +61,10 @@ void replaceDelegate(ResultSet delegate) { this.delegate = delegate; } + protected ResultSet getDelegate() { + return this.delegate; + } + private void checkClosed() { if (closed) { throw SpannerExceptionFactory.newSpannerException( @@ -77,6 +82,21 @@ public boolean next() throws SpannerException { return delegate.next(); } + @Override + public boolean canGetProtobufValue(int columnIndex) { + return !closed + && delegate instanceof ProtobufResultSet + && ((ProtobufResultSet) delegate).canGetProtobufValue(columnIndex); + } + + @Override + public com.google.protobuf.Value getProtobufValue(int columnIndex) { + checkClosed(); + Preconditions.checkState( + delegate instanceof ProtobufResultSet, "The result set does not support protobuf values"); + return ((ProtobufResultSet) getDelegate()).getProtobufValue(columnIndex); + } + @Override public Struct getCurrentRowAsStruct() { checkClosed(); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerPool.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerPool.java index 2a5a805c2c7..da8da78d92e 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerPool.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerPool.java @@ -17,6 +17,7 @@ package com.google.cloud.spanner.connection; import com.google.cloud.NoCredentials; +import com.google.cloud.spanner.DecodeMode; import com.google.cloud.spanner.ErrorCode; import com.google.cloud.spanner.SessionPoolOptions; import com.google.cloud.spanner.Spanner; @@ -342,6 +343,9 @@ Spanner createSpanner(SpannerPoolKey key, ConnectionOptions options) { .setClientLibToken(MoreObjects.firstNonNull(key.userAgent, CONNECTION_API_CLIENT_LIB_TOKEN)) .setHost(key.host) .setProjectId(key.projectId) + // Use lazy decoding, so we can use the protobuf values for calculating the checksum that is + // needed for read/write transactions. + .setDecodeMode(DecodeMode.LAZY_PER_COL) .setDatabaseRole(options.getDatabaseRole()) .setCredentials(options.getCredentials()); builder.setSessionPoolOption(key.sessionPoolOptions); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/DecodeModeTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/DecodeModeTest.java new file mode 100644 index 00000000000..6a6125e1dda --- /dev/null +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/DecodeModeTest.java @@ -0,0 +1,128 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner.connection; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.spanner.DecodeMode; +import com.google.cloud.spanner.ErrorCode; +import com.google.cloud.spanner.MockSpannerServiceImpl; +import com.google.cloud.spanner.Options; +import com.google.cloud.spanner.ResultSet; +import com.google.cloud.spanner.SpannerException; +import com.google.cloud.spanner.Statement; +import org.junit.After; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class DecodeModeTest extends AbstractMockServerTest { + + @After + public void clearRequests() { + mockSpanner.clearRequests(); + } + + @Test + public void testAllDecodeModes() { + int numRows = 10; + RandomResultSetGenerator generator = new RandomResultSetGenerator(numRows); + String sql = "select * from random"; + Statement statement = Statement.of(sql); + mockSpanner.putStatementResult( + MockSpannerServiceImpl.StatementResult.query(statement, generator.generate())); + + try (Connection connection = createConnection()) { + for (boolean readonly : new boolean[] {true, false}) { + for (boolean autocommit : new boolean[] {true, false}) { + connection.setReadOnly(readonly); + connection.setAutocommit(autocommit); + + int receivedRows = 0; + // DecodeMode#DIRECT is not supported in read/write transactions, as the protobuf value is + // used for checksum calculation. + try (ResultSet direct = + connection.executeQuery( + statement, + !readonly && !autocommit + ? Options.decodeMode(DecodeMode.LAZY_PER_ROW) + : Options.decodeMode(DecodeMode.DIRECT)); + ResultSet lazyPerRow = + connection.executeQuery(statement, Options.decodeMode(DecodeMode.LAZY_PER_ROW)); + ResultSet lazyPerCol = + connection.executeQuery(statement, Options.decodeMode(DecodeMode.LAZY_PER_COL))) { + while (direct.next() && lazyPerRow.next() && lazyPerCol.next()) { + assertEquals(direct.getColumnCount(), lazyPerRow.getColumnCount()); + assertEquals(direct.getColumnCount(), lazyPerCol.getColumnCount()); + for (int col = 0; col < direct.getColumnCount(); col++) { + // Test getting the entire row as a struct both as the first thing we do, and as the + // last thing we do. This ensures that the method works as expected both when a row + // is lazily decoded by this method, and when it has already been decoded by another + // method. + if (col % 2 == 0) { + assertEquals(direct.getCurrentRowAsStruct(), lazyPerRow.getCurrentRowAsStruct()); + assertEquals(direct.getCurrentRowAsStruct(), lazyPerCol.getCurrentRowAsStruct()); + } + assertEquals(direct.isNull(col), lazyPerRow.isNull(col)); + assertEquals(direct.isNull(col), lazyPerCol.isNull(col)); + assertEquals(direct.getValue(col), lazyPerRow.getValue(col)); + assertEquals(direct.getValue(col), lazyPerCol.getValue(col)); + if (col % 2 == 1) { + assertEquals(direct.getCurrentRowAsStruct(), lazyPerRow.getCurrentRowAsStruct()); + assertEquals(direct.getCurrentRowAsStruct(), lazyPerCol.getCurrentRowAsStruct()); + } + } + receivedRows++; + } + assertEquals(numRows, receivedRows); + } + if (!autocommit) { + connection.commit(); + } + } + } + } + } + + @Test + public void testDecodeModeDirect_failsInReadWriteTransaction() { + int numRows = 1; + RandomResultSetGenerator generator = new RandomResultSetGenerator(numRows); + String sql = "select * from random"; + Statement statement = Statement.of(sql); + mockSpanner.putStatementResult( + MockSpannerServiceImpl.StatementResult.query(statement, generator.generate())); + + try (Connection connection = createConnection()) { + connection.setAutocommit(false); + try (ResultSet resultSet = + connection.executeQuery(statement, Options.decodeMode(DecodeMode.DIRECT))) { + SpannerException exception = assertThrows(SpannerException.class, resultSet::next); + assertEquals(ErrorCode.FAILED_PRECONDITION, exception.getErrorCode()); + assertTrue( + exception.getMessage(), + exception + .getMessage() + .contains( + "Executing queries with DecodeMode#DIRECT is not supported in read/write transactions.")); + } + } + } +} diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/DirectExecuteResultSetTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/DirectExecuteResultSetTest.java index 1e4f96d1568..b14f837ff7b 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/DirectExecuteResultSetTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/DirectExecuteResultSetTest.java @@ -59,6 +59,7 @@ public void testMethodCallBeforeNext() throws IllegalAccessException, IllegalArgumentException, InvocationTargetException { List excludedMethods = Arrays.asList( + "canGetProtobufValue", "getStats", "getMetadata", "next", @@ -79,6 +80,7 @@ public void testMethodCallAfterClose() throws IllegalAccessException, IllegalArgumentException, InvocationTargetException { List excludedMethods = Arrays.asList( + "canGetProtobufValue", "getStats", "getMetadata", "next", @@ -101,6 +103,7 @@ public void testMethodCallAfterNextHasReturnedFalse() throws IllegalAccessException, IllegalArgumentException, InvocationTargetException { List excludedMethods = Arrays.asList( + "canGetProtobufValue", "getStats", "getMetadata", "next", diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/RandomResultSetGenerator.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/RandomResultSetGenerator.java index 3091364e17a..2067d36b5ea 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/RandomResultSetGenerator.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/RandomResultSetGenerator.java @@ -34,6 +34,7 @@ import com.google.spanner.v1.TypeAnnotationCode; import com.google.spanner.v1.TypeCode; import java.math.BigDecimal; +import java.math.RoundingMode; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -239,7 +240,10 @@ private void setRandomValue(Value.Builder builder, Type type) { if (dialect == Dialect.POSTGRESQL && randomNaN()) { builder.setStringValue("NaN"); } else { - builder.setStringValue(BigDecimal.valueOf(random.nextDouble()).toString()); + builder.setStringValue( + BigDecimal.valueOf(random.nextDouble()) + .setScale(9, RoundingMode.HALF_UP) + .toString()); } break; case INT64: diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReadWriteTransactionTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReadWriteTransactionTest.java index 0f083fd1e50..8e643cf6e24 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReadWriteTransactionTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReadWriteTransactionTest.java @@ -37,6 +37,7 @@ import com.google.cloud.spanner.CommitResponse; import com.google.cloud.spanner.DatabaseClient; import com.google.cloud.spanner.ErrorCode; +import com.google.cloud.spanner.ProtobufResultSet; import com.google.cloud.spanner.ReadContext.QueryAnalyzeMode; import com.google.cloud.spanner.ResultSet; import com.google.cloud.spanner.ResultSets; @@ -518,193 +519,197 @@ public void testChecksumResultSet() { .setGenre(Genre.FOLK) .build(); ProtocolMessageEnum protoEnumVal = Genre.ROCK; - ResultSet delegate1 = - ResultSets.forRows( - Type.struct( - StructField.of("ID", Type.int64()), - StructField.of("NAME", Type.string()), - StructField.of("AMOUNT", Type.numeric()), - StructField.of("JSON", Type.json()), - StructField.of( - "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), - StructField.of( - "PROTOENUM", - Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), - Arrays.asList( - Struct.newBuilder() - .set("ID") - .to(1L) - .set("NAME") - .to("TEST 1") - .set("AMOUNT") - .to(BigDecimal.valueOf(550, 2)) - .set("JSON") - .to(Value.json(simpleJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(protoEnumVal) - .build(), - Struct.newBuilder() - .set("ID") - .to(2L) - .set("NAME") - .to("TEST 2") - .set("AMOUNT") - .to(BigDecimal.valueOf(750, 2)) - .set("JSON") - .to(Value.json(arrayJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(Genre.JAZZ) - .build())); + ProtobufResultSet delegate1 = + (ProtobufResultSet) + ResultSets.forRows( + Type.struct( + StructField.of("ID", Type.int64()), + StructField.of("NAME", Type.string()), + StructField.of("AMOUNT", Type.numeric()), + StructField.of("JSON", Type.json()), + StructField.of( + "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), + StructField.of( + "PROTOENUM", + Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), + Arrays.asList( + Struct.newBuilder() + .set("ID") + .to(1L) + .set("NAME") + .to("TEST 1") + .set("AMOUNT") + .to(BigDecimal.valueOf(550, 2)) + .set("JSON") + .to(Value.json(simpleJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(protoEnumVal) + .build(), + Struct.newBuilder() + .set("ID") + .to(2L) + .set("NAME") + .to("TEST 2") + .set("AMOUNT") + .to(BigDecimal.valueOf(750, 2)) + .set("JSON") + .to(Value.json(arrayJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(Genre.JAZZ) + .build())); ChecksumResultSet rs1 = transaction.createChecksumResultSet(delegate1, parsedStatement, AnalyzeMode.NONE); - ResultSet delegate2 = - ResultSets.forRows( - Type.struct( - StructField.of("ID", Type.int64()), - StructField.of("NAME", Type.string()), - StructField.of("AMOUNT", Type.numeric()), - StructField.of("JSON", Type.json()), - StructField.of( - "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), - StructField.of( - "PROTOENUM", - Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), - Arrays.asList( - Struct.newBuilder() - .set("ID") - .to(1L) - .set("NAME") - .to("TEST 1") - .set("AMOUNT") - .to(new BigDecimal("5.50")) - .set("JSON") - .to(Value.json(simpleJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(protoEnumVal) - .build(), - Struct.newBuilder() - .set("ID") - .to(2L) - .set("NAME") - .to("TEST 2") - .set("AMOUNT") - .to(new BigDecimal("7.50")) - .set("JSON") - .to(Value.json(arrayJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(Genre.JAZZ) - .build())); + ProtobufResultSet delegate2 = + (ProtobufResultSet) + ResultSets.forRows( + Type.struct( + StructField.of("ID", Type.int64()), + StructField.of("NAME", Type.string()), + StructField.of("AMOUNT", Type.numeric()), + StructField.of("JSON", Type.json()), + StructField.of( + "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), + StructField.of( + "PROTOENUM", + Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), + Arrays.asList( + Struct.newBuilder() + .set("ID") + .to(1L) + .set("NAME") + .to("TEST 1") + .set("AMOUNT") + .to(new BigDecimal("5.50")) + .set("JSON") + .to(Value.json(simpleJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(protoEnumVal) + .build(), + Struct.newBuilder() + .set("ID") + .to(2L) + .set("NAME") + .to("TEST 2") + .set("AMOUNT") + .to(new BigDecimal("7.50")) + .set("JSON") + .to(Value.json(arrayJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(Genre.JAZZ) + .build())); ChecksumResultSet rs2 = transaction.createChecksumResultSet(delegate2, parsedStatement, AnalyzeMode.NONE); // rs1 and rs2 are equal, rs3 contains the same rows, but in a different order - ResultSet delegate3 = - ResultSets.forRows( - Type.struct( - StructField.of("ID", Type.int64()), - StructField.of("NAME", Type.string()), - StructField.of("AMOUNT", Type.numeric()), - StructField.of("JSON", Type.json()), - StructField.of( - "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), - StructField.of( - "PROTOENUM", - Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), - Arrays.asList( - Struct.newBuilder() - .set("ID") - .to(2L) - .set("NAME") - .to("TEST 2") - .set("AMOUNT") - .to(new BigDecimal("7.50")) - .set("JSON") - .to(Value.json(arrayJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(Genre.JAZZ) - .build(), - Struct.newBuilder() - .set("ID") - .to(1L) - .set("NAME") - .to("TEST 1") - .set("AMOUNT") - .to(new BigDecimal("5.50")) - .set("JSON") - .to(Value.json(simpleJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(protoEnumVal) - .build())); + ProtobufResultSet delegate3 = + (ProtobufResultSet) + ResultSets.forRows( + Type.struct( + StructField.of("ID", Type.int64()), + StructField.of("NAME", Type.string()), + StructField.of("AMOUNT", Type.numeric()), + StructField.of("JSON", Type.json()), + StructField.of( + "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), + StructField.of( + "PROTOENUM", + Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), + Arrays.asList( + Struct.newBuilder() + .set("ID") + .to(2L) + .set("NAME") + .to("TEST 2") + .set("AMOUNT") + .to(new BigDecimal("7.50")) + .set("JSON") + .to(Value.json(arrayJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(Genre.JAZZ) + .build(), + Struct.newBuilder() + .set("ID") + .to(1L) + .set("NAME") + .to("TEST 1") + .set("AMOUNT") + .to(new BigDecimal("5.50")) + .set("JSON") + .to(Value.json(simpleJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(protoEnumVal) + .build())); ChecksumResultSet rs3 = transaction.createChecksumResultSet(delegate3, parsedStatement, AnalyzeMode.NONE); // rs4 contains the same rows as rs1 and rs2, but also an additional row - ResultSet delegate4 = - ResultSets.forRows( - Type.struct( - StructField.of("ID", Type.int64()), - StructField.of("NAME", Type.string()), - StructField.of("AMOUNT", Type.numeric()), - StructField.of("JSON", Type.json()), - StructField.of( - "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), - StructField.of( - "PROTOENUM", - Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), - Arrays.asList( - Struct.newBuilder() - .set("ID") - .to(1L) - .set("NAME") - .to("TEST 1") - .set("AMOUNT") - .to(new BigDecimal("5.50")) - .set("JSON") - .to(Value.json(simpleJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(protoEnumVal) - .build(), - Struct.newBuilder() - .set("ID") - .to(2L) - .set("NAME") - .to("TEST 2") - .set("AMOUNT") - .to(new BigDecimal("7.50")) - .set("JSON") - .to(Value.json(arrayJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(Genre.JAZZ) - .build(), - Struct.newBuilder() - .set("ID") - .to(3L) - .set("NAME") - .to("TEST 3") - .set("AMOUNT") - .to(new BigDecimal("9.99")) - .set("JSON") - .to(Value.json(emptyArrayJson)) - .set("PROTO") - .to(null, SingerInfo.getDescriptor()) - .set("PROTOENUM") - .to(Genre.POP) - .build())); + ProtobufResultSet delegate4 = + (ProtobufResultSet) + ResultSets.forRows( + Type.struct( + StructField.of("ID", Type.int64()), + StructField.of("NAME", Type.string()), + StructField.of("AMOUNT", Type.numeric()), + StructField.of("JSON", Type.json()), + StructField.of( + "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), + StructField.of( + "PROTOENUM", + Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), + Arrays.asList( + Struct.newBuilder() + .set("ID") + .to(1L) + .set("NAME") + .to("TEST 1") + .set("AMOUNT") + .to(new BigDecimal("5.50")) + .set("JSON") + .to(Value.json(simpleJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(protoEnumVal) + .build(), + Struct.newBuilder() + .set("ID") + .to(2L) + .set("NAME") + .to("TEST 2") + .set("AMOUNT") + .to(new BigDecimal("7.50")) + .set("JSON") + .to(Value.json(arrayJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(Genre.JAZZ) + .build(), + Struct.newBuilder() + .set("ID") + .to(3L) + .set("NAME") + .to("TEST 3") + .set("AMOUNT") + .to(new BigDecimal("9.99")) + .set("JSON") + .to(Value.json(emptyArrayJson)) + .set("PROTO") + .to(null, SingerInfo.getDescriptor()) + .set("PROTOENUM") + .to(Genre.POP) + .build())); ChecksumResultSet rs4 = transaction.createChecksumResultSet(delegate4, parsedStatement, AnalyzeMode.NONE); @@ -736,44 +741,46 @@ public void testChecksumResultSetWithArray() { ParsedStatement parsedStatement = mock(ParsedStatement.class); Statement statement = Statement.of("SELECT * FROM FOO"); when(parsedStatement.getStatement()).thenReturn(statement); - ResultSet delegate1 = - ResultSets.forRows( - Type.struct( - StructField.of("ID", Type.int64()), - StructField.of("PRICES", Type.array(Type.int64()))), - Arrays.asList( - Struct.newBuilder() - .set("ID") - .to(1L) - .set("PRICES") - .toInt64Array(new long[] {1L, 2L}) - .build(), - Struct.newBuilder() - .set("ID") - .to(2L) - .set("PRICES") - .toInt64Array(new long[] {3L, 4L}) - .build())); + ProtobufResultSet delegate1 = + (ProtobufResultSet) + ResultSets.forRows( + Type.struct( + StructField.of("ID", Type.int64()), + StructField.of("PRICES", Type.array(Type.int64()))), + Arrays.asList( + Struct.newBuilder() + .set("ID") + .to(1L) + .set("PRICES") + .toInt64Array(new long[] {1L, 2L}) + .build(), + Struct.newBuilder() + .set("ID") + .to(2L) + .set("PRICES") + .toInt64Array(new long[] {3L, 4L}) + .build())); ChecksumResultSet rs1 = transaction.createChecksumResultSet(delegate1, parsedStatement, AnalyzeMode.NONE); - ResultSet delegate2 = - ResultSets.forRows( - Type.struct( - StructField.of("ID", Type.int64()), - StructField.of("PRICES", Type.array(Type.int64()))), - Arrays.asList( - Struct.newBuilder() - .set("ID") - .to(1L) - .set("PRICES") - .toInt64Array(new long[] {1L, 2L}) - .build(), - Struct.newBuilder() - .set("ID") - .to(2L) - .set("PRICES") - .toInt64Array(new long[] {3L, 5L}) - .build())); + ProtobufResultSet delegate2 = + (ProtobufResultSet) + ResultSets.forRows( + Type.struct( + StructField.of("ID", Type.int64()), + StructField.of("PRICES", Type.array(Type.int64()))), + Arrays.asList( + Struct.newBuilder() + .set("ID") + .to(1L) + .set("PRICES") + .toInt64Array(new long[] {1L, 2L}) + .build(), + Struct.newBuilder() + .set("ID") + .to(2L) + .set("PRICES") + .toInt64Array(new long[] {3L, 5L}) + .build())); ChecksumResultSet rs2 = transaction.createChecksumResultSet(delegate2, parsedStatement, AnalyzeMode.NONE); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSetTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSetTest.java index bbb34675147..4617c47bc6b 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSetTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSetTest.java @@ -104,7 +104,14 @@ public void testReplace() { public void testMethodCallBeforeNext() throws IllegalAccessException, IllegalArgumentException, InvocationTargetException { List excludedMethods = - Arrays.asList("getStats", "getMetadata", "next", "close", "equals", "hashCode"); + Arrays.asList( + "canGetProtobufValue", + "getStats", + "getMetadata", + "next", + "close", + "equals", + "hashCode"); ReplaceableForwardingResultSet subject = createSubject(); // Test that all methods throw an IllegalStateException except the excluded methods when called // before a call to ResultSet#next(). @@ -116,6 +123,7 @@ public void testMethodCallAfterClose() throws IllegalAccessException, IllegalArgumentException, InvocationTargetException { List excludedMethods = Arrays.asList( + "canGetProtobufValue", "getStats", "getMetadata", "next", @@ -140,6 +148,7 @@ public void testMethodCallAfterNextHasReturnedFalse() throws IllegalAccessException, IllegalArgumentException, InvocationTargetException { List excludedMethods = Arrays.asList( + "canGetProtobufValue", "getStats", "getMetadata", "next",