diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java index 2505b8b4985..ae18a1e4f5f 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java @@ -70,6 +70,7 @@ abstract static class Builder, T extends AbstractReadCon private TraceWrapper tracer; private int defaultPrefetchChunks = SpannerOptions.Builder.DEFAULT_PREFETCH_CHUNKS; private QueryOptions defaultQueryOptions = SpannerOptions.Builder.DEFAULT_QUERY_OPTIONS; + private DecodeMode defaultDecodeMode = SpannerOptions.Builder.DEFAULT_DECODE_MODE; private DirectedReadOptions defaultDirectedReadOption; private ExecutorProvider executorProvider; private Clock clock = new Clock(); @@ -111,6 +112,11 @@ B setDefaultQueryOptions(QueryOptions defaultQueryOptions) { return self(); } + B setDefaultDecodeMode(DecodeMode defaultDecodeMode) { + this.defaultDecodeMode = defaultDecodeMode; + return self(); + } + B setExecutorProvider(ExecutorProvider executorProvider) { this.executorProvider = executorProvider; return self(); @@ -411,8 +417,8 @@ void initTransaction() { TraceWrapper tracer; private final int defaultPrefetchChunks; private final QueryOptions defaultQueryOptions; - private final DirectedReadOptions defaultDirectedReadOptions; + private final DecodeMode defaultDecodeMode; private final Clock clock; @GuardedBy("lock") @@ -438,6 +444,7 @@ void initTransaction() { this.defaultPrefetchChunks = builder.defaultPrefetchChunks; this.defaultQueryOptions = builder.defaultQueryOptions; this.defaultDirectedReadOptions = builder.defaultDirectedReadOption; + this.defaultDecodeMode = builder.defaultDecodeMode; this.span = builder.span; this.executorProvider = builder.executorProvider; this.clock = builder.clock; @@ -727,7 +734,8 @@ CloseableIterator startStream(@Nullable ByteString resumeToken return stream; } }; - return new GrpcResultSet(stream, this); + return new GrpcResultSet( + stream, this, options.hasDecodeMode() ? options.decodeMode() : defaultDecodeMode); } /** @@ -871,7 +879,8 @@ CloseableIterator startStream(@Nullable ByteString resumeToken return stream; } }; - return new GrpcResultSet(stream, this); + return new GrpcResultSet( + stream, this, readOptions.hasDecodeMode() ? readOptions.decodeMode() : defaultDecodeMode); } private Struct consumeSingleRow(ResultSet resultSet) { diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java index 664cde1edbb..22fb9f710c1 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java @@ -60,6 +60,7 @@ public BatchReadOnlyTransaction batchReadOnlyTransaction(TimestampBound bound) { sessionClient.getSpanner().getDefaultQueryOptions(sessionClient.getDatabaseId())) .setExecutorProvider(sessionClient.getSpanner().getAsyncExecutorProvider()) .setDefaultPrefetchChunks(sessionClient.getSpanner().getDefaultPrefetchChunks()) + .setDefaultDecodeMode(sessionClient.getSpanner().getDefaultDecodeMode()) .setDefaultDirectedReadOptions( sessionClient.getSpanner().getOptions().getDirectedReadOptions()) .setSpan(sessionClient.getSpanner().getTracer().getCurrentSpan()) @@ -81,6 +82,7 @@ public BatchReadOnlyTransaction batchReadOnlyTransaction(BatchTransactionId batc sessionClient.getSpanner().getDefaultQueryOptions(sessionClient.getDatabaseId())) .setExecutorProvider(sessionClient.getSpanner().getAsyncExecutorProvider()) .setDefaultPrefetchChunks(sessionClient.getSpanner().getDefaultPrefetchChunks()) + .setDefaultDecodeMode(sessionClient.getSpanner().getDefaultDecodeMode()) .setDefaultDirectedReadOptions( sessionClient.getSpanner().getOptions().getDirectedReadOptions()) .setSpan(sessionClient.getSpanner().getTracer().getCurrentSpan()) diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DecodeMode.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DecodeMode.java new file mode 100644 index 00000000000..c1bea9a3ce1 --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DecodeMode.java @@ -0,0 +1,35 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +/** Specifies how and when to decode a value from protobuf to a plain Java object. */ +public enum DecodeMode { + /** + * Decodes all columns of a row directly when a {@link ResultSet} is advanced to the next row with + * {@link ResultSet#next()} + */ + DIRECT, + /** + * Decodes all columns of a row the first time a {@link ResultSet} value is retrieved from the + * row. + */ + LAZY_PER_ROW, + /** + * Decodes a columns of a row the first time the value of that column is retrieved from the row. + */ + LAZY_PER_COL, +} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java index c29282879ed..18ecbeceb0f 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java @@ -23,7 +23,7 @@ import com.google.spanner.v1.ResultSetStats; /** Forwarding implementation of ResultSet that forwards all calls to a delegate. */ -public class ForwardingResultSet extends ForwardingStructReader implements ResultSet { +public class ForwardingResultSet extends ForwardingStructReader implements ProtobufResultSet { private Supplier delegate; @@ -55,6 +55,22 @@ public boolean next() throws SpannerException { return delegate.get().next(); } + @Override + public boolean canGetProtobufValue(int columnIndex) { + ResultSet resultSetDelegate = delegate.get(); + return (resultSetDelegate instanceof ProtobufResultSet) + && ((ProtobufResultSet) resultSetDelegate).canGetProtobufValue(columnIndex); + } + + @Override + public com.google.protobuf.Value getProtobufValue(int columnIndex) { + ResultSet resultSetDelegate = delegate.get(); + Preconditions.checkState( + resultSetDelegate instanceof ProtobufResultSet, + "The result set does not support protobuf values"); + return ((ProtobufResultSet) resultSetDelegate).getProtobufValue(columnIndex); + } + @Override public Struct getCurrentRowAsStruct() { return delegate.get().getCurrentRowAsStruct(); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcResultSet.java index b5d64bce3bb..37a4792ad87 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcResultSet.java @@ -20,6 +20,7 @@ import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; +import com.google.protobuf.Value; import com.google.spanner.v1.PartialResultSet; import com.google.spanner.v1.ResultSetMetadata; import com.google.spanner.v1.ResultSetStats; @@ -28,9 +29,10 @@ import javax.annotation.Nullable; @VisibleForTesting -class GrpcResultSet extends AbstractResultSet> { +class GrpcResultSet extends AbstractResultSet> implements ProtobufResultSet { private final GrpcValueIterator iterator; private final Listener listener; + private final DecodeMode decodeMode; private ResultSetMetadata metadata; private GrpcStruct currRow; private SpannerException error; @@ -38,8 +40,26 @@ class GrpcResultSet extends AbstractResultSet> { private boolean closed; GrpcResultSet(CloseableIterator iterator, Listener listener) { + this(iterator, listener, DecodeMode.DIRECT); + } + + GrpcResultSet( + CloseableIterator iterator, Listener listener, DecodeMode decodeMode) { this.iterator = new GrpcValueIterator(iterator); this.listener = listener; + this.decodeMode = decodeMode; + } + + @Override + public boolean canGetProtobufValue(int columnIndex) { + return !closed && currRow != null && currRow.canGetProtoValue(columnIndex); + } + + @Override + public Value getProtobufValue(int columnIndex) { + checkState(!closed, "ResultSet is closed"); + checkState(currRow != null, "next() call required"); + return currRow.getProtoValueInternal(columnIndex); } @Override @@ -65,7 +85,7 @@ public boolean next() throws SpannerException { throw SpannerExceptionFactory.newSpannerException( ErrorCode.FAILED_PRECONDITION, AbstractReadContext.NO_TRANSACTION_RETURNED_MSG); } - currRow = new GrpcStruct(iterator.type(), new ArrayList<>()); + currRow = new GrpcStruct(iterator.type(), new ArrayList<>(), decodeMode); } boolean hasNext = currRow.consumeRow(iterator); if (!hasNext) { diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStruct.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStruct.java index 0d8a5545a90..152c82e9ca9 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStruct.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStruct.java @@ -27,6 +27,7 @@ import com.google.cloud.spanner.AbstractResultSet.Float64Array; import com.google.cloud.spanner.AbstractResultSet.Int64Array; import com.google.cloud.spanner.AbstractResultSet.LazyByteArray; +import com.google.cloud.spanner.Type.Code; import com.google.cloud.spanner.Type.StructField; import com.google.common.base.Preconditions; import com.google.common.collect.Lists; @@ -42,6 +43,7 @@ import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Base64; +import java.util.BitSet; import java.util.Collections; import java.util.Iterator; import java.util.List; @@ -54,6 +56,9 @@ class GrpcStruct extends Struct implements Serializable { private final Type type; private final List rowData; + private final DecodeMode decodeMode; + private final BitSet colDecoded; + private boolean rowDecoded; /** * Builds an immutable version of this struct using {@link Struct#newBuilder()} which is used as a @@ -181,9 +186,28 @@ private Object writeReplace() { return builder.build(); } - GrpcStruct(Type type, List rowData) { + GrpcStruct(Type type, List rowData, DecodeMode decodeMode) { + this( + type, + rowData, + decodeMode, + /* rowDecoded = */ false, + /* colDecoded = */ decodeMode == DecodeMode.LAZY_PER_COL + ? new BitSet(type.getStructFields().size()) + : null); + } + + private GrpcStruct( + Type type, + List rowData, + DecodeMode decodeMode, + boolean rowDecoded, + BitSet colDecoded) { this.type = type; this.rowData = rowData; + this.decodeMode = decodeMode; + this.rowDecoded = rowDecoded; + this.colDecoded = colDecoded; } @Override @@ -193,6 +217,11 @@ public String toString() { boolean consumeRow(Iterator iterator) { rowData.clear(); + if (decodeMode == DecodeMode.LAZY_PER_ROW) { + rowDecoded = false; + } else if (decodeMode == DecodeMode.LAZY_PER_COL) { + colDecoded.clear(); + } if (!iterator.hasNext()) { return false; } @@ -203,7 +232,11 @@ boolean consumeRow(Iterator iterator) { "Invalid value stream: end of stream reached before row is complete"); } com.google.protobuf.Value value = iterator.next(); - rowData.add(decodeValue(fieldType.getType(), value)); + if (decodeMode == DecodeMode.DIRECT) { + rowData.add(decodeValue(fieldType.getType(), value)); + } else { + rowData.add(value); + } } return true; } @@ -266,7 +299,7 @@ private static Struct decodeStructValue(Type structType, ListValue structValue) for (int i = 0; i < fieldTypes.size(); ++i) { fields.add(decodeValue(fieldTypes.get(i).getType(), fieldValues.get(i))); } - return new GrpcStruct(structType, fields); + return new GrpcStruct(structType, fields, DecodeMode.DIRECT); } static Object decodeArrayValue(Type elementType, ListValue listValue) { @@ -310,7 +343,12 @@ private static void checkType( } Struct immutableCopy() { - return new GrpcStruct(type, new ArrayList<>(rowData)); + return new GrpcStruct( + type, + new ArrayList<>(rowData), + this.decodeMode, + this.rowDecoded, + this.colDecoded == null ? null : (BitSet) this.colDecoded.clone()); } @Override @@ -320,6 +358,10 @@ public Type getType() { @Override public boolean isNull(int columnIndex) { + if ((decodeMode == DecodeMode.LAZY_PER_ROW && !rowDecoded) + || (decodeMode == DecodeMode.LAZY_PER_COL && !colDecoded.get(columnIndex))) { + return ((com.google.protobuf.Value) rowData.get(columnIndex)).hasNullValue(); + } return rowData.get(columnIndex) == null; } @@ -355,64 +397,123 @@ protected T getProtoEnumInternal( @Override protected boolean getBooleanInternal(int columnIndex) { + ensureDecoded(columnIndex); return (Boolean) rowData.get(columnIndex); } @Override protected long getLongInternal(int columnIndex) { + ensureDecoded(columnIndex); return (Long) rowData.get(columnIndex); } @Override protected double getDoubleInternal(int columnIndex) { + ensureDecoded(columnIndex); return (Double) rowData.get(columnIndex); } @Override protected BigDecimal getBigDecimalInternal(int columnIndex) { + ensureDecoded(columnIndex); return (BigDecimal) rowData.get(columnIndex); } @Override protected String getStringInternal(int columnIndex) { + ensureDecoded(columnIndex); return (String) rowData.get(columnIndex); } @Override protected String getJsonInternal(int columnIndex) { + ensureDecoded(columnIndex); return (String) rowData.get(columnIndex); } @Override protected String getPgJsonbInternal(int columnIndex) { + ensureDecoded(columnIndex); return (String) rowData.get(columnIndex); } @Override protected ByteArray getBytesInternal(int columnIndex) { + ensureDecoded(columnIndex); return getLazyBytesInternal(columnIndex).getByteArray(); } LazyByteArray getLazyBytesInternal(int columnIndex) { + ensureDecoded(columnIndex); return (LazyByteArray) rowData.get(columnIndex); } @Override protected Timestamp getTimestampInternal(int columnIndex) { + ensureDecoded(columnIndex); return (Timestamp) rowData.get(columnIndex); } @Override protected Date getDateInternal(int columnIndex) { + ensureDecoded(columnIndex); return (Date) rowData.get(columnIndex); } + private boolean isUnrecognizedType(int columnIndex) { + return type.getStructFields().get(columnIndex).getType().getCode() == Code.UNRECOGNIZED; + } + + boolean canGetProtoValue(int columnIndex) { + return isUnrecognizedType(columnIndex) + || (decodeMode == DecodeMode.LAZY_PER_ROW && !rowDecoded) + || (decodeMode == DecodeMode.LAZY_PER_COL && !colDecoded.get(columnIndex)); + } + protected com.google.protobuf.Value getProtoValueInternal(int columnIndex) { + checkProtoValueSupported(columnIndex); return (com.google.protobuf.Value) rowData.get(columnIndex); } + private void checkProtoValueSupported(int columnIndex) { + // Unrecognized types are returned as protobuf values. + if (isUnrecognizedType(columnIndex)) { + return; + } + Preconditions.checkState( + decodeMode != DecodeMode.DIRECT, + "Getting proto value is not supported when DecodeMode#DIRECT is used."); + Preconditions.checkState( + !(decodeMode == DecodeMode.LAZY_PER_ROW && rowDecoded), + "Getting proto value after the row has been decoded is not supported."); + Preconditions.checkState( + !(decodeMode == DecodeMode.LAZY_PER_COL && colDecoded.get(columnIndex)), + "Getting proto value after the column has been decoded is not supported."); + } + + private void ensureDecoded(int columnIndex) { + if (decodeMode == DecodeMode.LAZY_PER_ROW && !rowDecoded) { + for (int i = 0; i < rowData.size(); i++) { + rowData.set( + i, + decodeValue( + type.getStructFields().get(i).getType(), + (com.google.protobuf.Value) rowData.get(i))); + } + rowDecoded = true; + } else if (decodeMode == DecodeMode.LAZY_PER_COL && !colDecoded.get(columnIndex)) { + rowData.set( + columnIndex, + decodeValue( + type.getStructFields().get(columnIndex).getType(), + (com.google.protobuf.Value) rowData.get(columnIndex))); + colDecoded.set(columnIndex); + } + } + @Override protected Value getValueInternal(int columnIndex) { + ensureDecoded(columnIndex); final List structFields = getType().getStructFields(); final StructField structField = structFields.get(columnIndex); final Type columnType = structField.getType(); @@ -423,7 +524,8 @@ protected Value getValueInternal(int columnIndex) { case INT64: return Value.int64(isNull ? null : getLongInternal(columnIndex)); case ENUM: - return Value.protoEnum(getLongInternal(columnIndex), columnType.getProtoTypeFqn()); + return Value.protoEnum( + isNull ? null : getLongInternal(columnIndex), columnType.getProtoTypeFqn()); case NUMERIC: return Value.numeric(isNull ? null : getBigDecimalInternal(columnIndex)); case PG_NUMERIC: @@ -439,7 +541,8 @@ protected Value getValueInternal(int columnIndex) { case BYTES: return Value.internalBytes(isNull ? null : getLazyBytesInternal(columnIndex)); case PROTO: - return Value.protoMessage(getBytesInternal(columnIndex), columnType.getProtoTypeFqn()); + return Value.protoMessage( + isNull ? null : getBytesInternal(columnIndex), columnType.getProtoTypeFqn()); case TIMESTAMP: return Value.timestamp(isNull ? null : getTimestampInternal(columnIndex)); case DATE: @@ -494,11 +597,13 @@ protected Value getValueInternal(int columnIndex) { @Override protected Struct getStructInternal(int columnIndex) { + ensureDecoded(columnIndex); return (Struct) rowData.get(columnIndex); } @Override protected boolean[] getBooleanArrayInternal(int columnIndex) { + ensureDecoded(columnIndex); @SuppressWarnings("unchecked") // We know ARRAY produces a List. List values = (List) rowData.get(columnIndex); boolean[] r = new boolean[values.size()]; @@ -514,44 +619,52 @@ protected boolean[] getBooleanArrayInternal(int columnIndex) { @Override @SuppressWarnings("unchecked") // We know ARRAY produces a List. protected List getBooleanListInternal(int columnIndex) { + ensureDecoded(columnIndex); return Collections.unmodifiableList((List) rowData.get(columnIndex)); } @Override protected long[] getLongArrayInternal(int columnIndex) { + ensureDecoded(columnIndex); return getLongListInternal(columnIndex).toPrimitiveArray(columnIndex); } @Override protected Int64Array getLongListInternal(int columnIndex) { + ensureDecoded(columnIndex); return (Int64Array) rowData.get(columnIndex); } @Override protected double[] getDoubleArrayInternal(int columnIndex) { + ensureDecoded(columnIndex); return getDoubleListInternal(columnIndex).toPrimitiveArray(columnIndex); } @Override protected Float64Array getDoubleListInternal(int columnIndex) { + ensureDecoded(columnIndex); return (Float64Array) rowData.get(columnIndex); } @Override @SuppressWarnings("unchecked") // We know ARRAY produces a List. protected List getBigDecimalListInternal(int columnIndex) { + ensureDecoded(columnIndex); return (List) rowData.get(columnIndex); } @Override @SuppressWarnings("unchecked") // We know ARRAY produces a List. protected List getStringListInternal(int columnIndex) { + ensureDecoded(columnIndex); return Collections.unmodifiableList((List) rowData.get(columnIndex)); } @Override @SuppressWarnings("unchecked") // We know ARRAY produces a List. protected List getJsonListInternal(int columnIndex) { + ensureDecoded(columnIndex); return Collections.unmodifiableList((List) rowData.get(columnIndex)); } @@ -562,6 +675,7 @@ protected List getProtoMessageListInternal( Preconditions.checkNotNull( message, "Proto message may not be null. Use MyProtoClass.getDefaultInstance() as a parameter value."); + ensureDecoded(columnIndex); List bytesArray = (List) rowData.get(columnIndex); @@ -596,6 +710,7 @@ protected List getProtoEnumListInternal( int columnIndex, Function method) { Preconditions.checkNotNull( method, "Method may not be null. Use 'MyProtoEnum::forNumber' as a parameter value."); + ensureDecoded(columnIndex); List enumIntArray = (List) rowData.get(columnIndex); List protoEnumList = new ArrayList<>(enumIntArray.size()); @@ -613,12 +728,14 @@ protected List getProtoEnumListInternal( @Override @SuppressWarnings("unchecked") // We know ARRAY produces a List. protected List getPgJsonbListInternal(int columnIndex) { + ensureDecoded(columnIndex); return Collections.unmodifiableList((List) rowData.get(columnIndex)); } @Override @SuppressWarnings("unchecked") // We know ARRAY produces a List. protected List getBytesListInternal(int columnIndex) { + ensureDecoded(columnIndex); return Lists.transform( (List) rowData.get(columnIndex), l -> l == null ? null : l.getByteArray()); } @@ -626,18 +743,21 @@ protected List getBytesListInternal(int columnIndex) { @Override @SuppressWarnings("unchecked") // We know ARRAY produces a List. protected List getTimestampListInternal(int columnIndex) { + ensureDecoded(columnIndex); return Collections.unmodifiableList((List) rowData.get(columnIndex)); } @Override @SuppressWarnings("unchecked") // We know ARRAY produces a List. protected List getDateListInternal(int columnIndex) { + ensureDecoded(columnIndex); return Collections.unmodifiableList((List) rowData.get(columnIndex)); } @Override @SuppressWarnings("unchecked") // We know ARRAY> produces a List. protected List getStructListInternal(int columnIndex) { + ensureDecoded(columnIndex); return Collections.unmodifiableList((List) rowData.get(columnIndex)); } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java index 57feabbfcca..76d0f24225a 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java @@ -243,6 +243,10 @@ public static ReadAndQueryOption directedRead(DirectedReadOptions directedReadOp return new DirectedReadOption(directedReadOptions); } + public static ReadAndQueryOption decodeMode(DecodeMode decodeMode) { + return new DecodeOption(decodeMode); + } + /** Option to request {@link CommitStats} for read/write transactions. */ static final class CommitStatsOption extends InternalOption implements TransactionOption { @Override @@ -374,6 +378,19 @@ void appendToOptions(Options options) { } } + static final class DecodeOption extends InternalOption implements ReadAndQueryOption { + private final DecodeMode decodeMode; + + DecodeOption(DecodeMode decodeMode) { + this.decodeMode = Preconditions.checkNotNull(decodeMode, "DecodeMode cannot be null"); + } + + @Override + void appendToOptions(Options options) { + options.decodeMode = decodeMode; + } + } + private boolean withCommitStats; private Duration maxCommitDelay; @@ -391,6 +408,7 @@ void appendToOptions(Options options) { private Boolean withOptimisticLock; private Boolean dataBoostEnabled; private DirectedReadOptions directedReadOptions; + private DecodeMode decodeMode; // Construction is via factory methods below. private Options() {} @@ -507,6 +525,14 @@ DirectedReadOptions directedReadOptions() { return directedReadOptions; } + boolean hasDecodeMode() { + return decodeMode != null; + } + + DecodeMode decodeMode() { + return decodeMode; + } + @Override public String toString() { StringBuilder b = new StringBuilder(); @@ -552,6 +578,9 @@ public String toString() { if (directedReadOptions != null) { b.append("directedReadOptions: ").append(directedReadOptions).append(' '); } + if (decodeMode != null) { + b.append("decodeMode: ").append(decodeMode).append(' '); + } return b.toString(); } @@ -640,6 +669,9 @@ public int hashCode() { if (directedReadOptions != null) { result = 31 * result + directedReadOptions.hashCode(); } + if (decodeMode != null) { + result = 31 * result + decodeMode.hashCode(); + } return result; } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ProtobufResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ProtobufResultSet.java new file mode 100644 index 00000000000..bbd8c41291f --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ProtobufResultSet.java @@ -0,0 +1,42 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import com.google.api.core.InternalApi; +import com.google.protobuf.Value; + +/** Interface for {@link ResultSet}s that can return a protobuf value. */ +@InternalApi +public interface ProtobufResultSet extends ResultSet { + + /** Returns true if the protobuf value for the given column is still available. */ + boolean canGetProtobufValue(int columnIndex); + + /** + * Returns the column value as a protobuf value. + * + *

This is an internal method not intended for external usage. + * + *

This method may only be called before the column value has been decoded to a plain Java + * object. This means that the {@link DecodeMode} that is used for the {@link ResultSet} must be + * one of {@link DecodeMode#LAZY_PER_ROW} and {@link DecodeMode#LAZY_PER_COL}, and that the + * corresponding {@link ResultSet#getValue(int)}, {@link ResultSet#getBoolean(int)}, ... method + * may not yet have been called for the column. + */ + @InternalApi + Value getProtobufValue(int columnIndex); +} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResultSets.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResultSets.java index d55d4091b9f..a6cc7c729e5 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResultSets.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResultSets.java @@ -109,7 +109,7 @@ public ResultSet get() { } } - private static class PrePopulatedResultSet implements ResultSet { + private static class PrePopulatedResultSet implements ProtobufResultSet { private final List rows; private final Type type; private int index = -1; @@ -137,6 +137,19 @@ public boolean next() throws SpannerException { return ++index < rows.size(); } + @Override + public boolean canGetProtobufValue(int columnIndex) { + return !closed && index >= 0 && index < rows.size(); + } + + @Override + public com.google.protobuf.Value getProtobufValue(int columnIndex) { + Preconditions.checkState(!closed, "ResultSet is closed"); + Preconditions.checkState(index >= 0, "Must be preceded by a next() call"); + Preconditions.checkElementIndex(index, rows.size(), "All rows have been yielded"); + return getValue(columnIndex).toProto(); + } + @Override public Struct getCurrentRowAsStruct() { Preconditions.checkState(!closed, "ResultSet is closed"); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java index 29928f61cec..81b00001105 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java @@ -263,6 +263,7 @@ public ReadContext singleUse(TimestampBound bound) { .setRpc(spanner.getRpc()) .setDefaultQueryOptions(spanner.getDefaultQueryOptions(databaseId)) .setDefaultPrefetchChunks(spanner.getDefaultPrefetchChunks()) + .setDefaultDecodeMode(spanner.getDefaultDecodeMode()) .setDefaultDirectedReadOptions(spanner.getOptions().getDirectedReadOptions()) .setSpan(currentSpan) .setTracer(tracer) @@ -284,6 +285,7 @@ public ReadOnlyTransaction singleUseReadOnlyTransaction(TimestampBound bound) { .setRpc(spanner.getRpc()) .setDefaultQueryOptions(spanner.getDefaultQueryOptions(databaseId)) .setDefaultPrefetchChunks(spanner.getDefaultPrefetchChunks()) + .setDefaultDecodeMode(spanner.getDefaultDecodeMode()) .setDefaultDirectedReadOptions(spanner.getOptions().getDirectedReadOptions()) .setSpan(currentSpan) .setTracer(tracer) @@ -305,6 +307,7 @@ public ReadOnlyTransaction readOnlyTransaction(TimestampBound bound) { .setRpc(spanner.getRpc()) .setDefaultQueryOptions(spanner.getDefaultQueryOptions(databaseId)) .setDefaultPrefetchChunks(spanner.getDefaultPrefetchChunks()) + .setDefaultDecodeMode(spanner.getDefaultDecodeMode()) .setDefaultDirectedReadOptions(spanner.getOptions().getDirectedReadOptions()) .setSpan(currentSpan) .setTracer(tracer) @@ -423,6 +426,7 @@ TransactionContextImpl newTransaction(Options options) { .setRpc(spanner.getRpc()) .setDefaultQueryOptions(spanner.getDefaultQueryOptions(databaseId)) .setDefaultPrefetchChunks(spanner.getDefaultPrefetchChunks()) + .setDefaultDecodeMode(spanner.getDefaultDecodeMode()) .setSpan(currentSpan) .setTracer(tracer) .setExecutorProvider(spanner.getAsyncExecutorProvider()) diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java index 326a51d803e..8fe06f76cc8 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java @@ -151,6 +151,10 @@ int getDefaultPrefetchChunks() { return getOptions().getPrefetchChunks(); } + DecodeMode getDefaultDecodeMode() { + return getOptions().getDecodeMode(); + } + /** Returns the default query options that should be used for the specified database. */ QueryOptions getDefaultQueryOptions(DatabaseId databaseId) { return getOptions().getDefaultQueryOptions(databaseId); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java index 9c6044aa938..a16be179ce3 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java @@ -113,6 +113,7 @@ public class SpannerOptions extends ServiceOptions { private final GrpcInterceptorProvider interceptorProvider; private final SessionPoolOptions sessionPoolOptions; private final int prefetchChunks; + private final DecodeMode decodeMode; private final int numChannels; private final String transportChannelExecutorThreadNameFormat; private final String databaseRole; @@ -616,6 +617,7 @@ protected SpannerOptions(Builder builder) { ? builder.sessionPoolOptions : SessionPoolOptions.newBuilder().build(); prefetchChunks = builder.prefetchChunks; + decodeMode = builder.decodeMode; databaseRole = builder.databaseRole; sessionLabels = builder.sessionLabels; try { @@ -704,6 +706,9 @@ public static class Builder extends ServiceOptions.Builder { static final int DEFAULT_PREFETCH_CHUNKS = 4; static final QueryOptions DEFAULT_QUERY_OPTIONS = QueryOptions.getDefaultInstance(); + // TODO: Set the default to DecodeMode.DIRECT before merging to keep the current default. + // It is currently set to LAZY_PER_COL so it is used in all tests. + static final DecodeMode DEFAULT_DECODE_MODE = DecodeMode.LAZY_PER_COL; static final RetrySettings DEFAULT_ADMIN_REQUESTS_LIMIT_EXCEEDED_RETRY_SETTINGS = RetrySettings.newBuilder() .setInitialRetryDelay(Duration.ofSeconds(5L)) @@ -730,6 +735,7 @@ public static class Builder private String transportChannelExecutorThreadNameFormat = "Cloud-Spanner-TransportChannel-%d"; private int prefetchChunks = DEFAULT_PREFETCH_CHUNKS; + private DecodeMode decodeMode = DEFAULT_DECODE_MODE; private SessionPoolOptions sessionPoolOptions; private String databaseRole; private ImmutableMap sessionLabels; @@ -797,6 +803,7 @@ protected Builder() { options.transportChannelExecutorThreadNameFormat; this.sessionPoolOptions = options.sessionPoolOptions; this.prefetchChunks = options.prefetchChunks; + this.decodeMode = options.decodeMode; this.databaseRole = options.databaseRole; this.sessionLabels = options.sessionLabels; this.spannerStubSettingsBuilder = options.spannerStubSettings.toBuilder(); @@ -1224,6 +1231,15 @@ public Builder setPrefetchChunks(int prefetchChunks) { return this; } + /** + * Specifies how values that are returned from a query should be decoded and converted from + * protobuf values into plain Java objects. + */ + public Builder setDecodeMode(DecodeMode decodeMode) { + this.decodeMode = decodeMode; + return this; + } + @Override public Builder setHost(String host) { super.setHost(host); @@ -1568,6 +1584,10 @@ public int getPrefetchChunks() { return prefetchChunks; } + public DecodeMode getDecodeMode() { + return decodeMode; + } + public static GrpcTransportOptions getDefaultGrpcTransportOptions() { return GrpcTransportOptions.newBuilder().build(); } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Value.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Value.java index a845eb118bf..3f0155e4a5e 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Value.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Value.java @@ -1518,6 +1518,10 @@ void valueToString(StringBuilder b) { @Override boolean valueEquals(Value v) { + // NaN == NaN always returns false, so we need a custom check. + if (Double.isNaN(this.value)) { + return Double.isNaN(((Float64Impl) v).value); + } return ((Float64Impl) v).value == value; } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ChecksumResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ChecksumResultSet.java index dc373cf03bd..c642d7e505a 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ChecksumResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ChecksumResultSet.java @@ -16,28 +16,28 @@ package com.google.cloud.spanner.connection; -import com.google.cloud.ByteArray; -import com.google.cloud.Date; -import com.google.cloud.Timestamp; import com.google.cloud.spanner.AbortedException; +import com.google.cloud.spanner.ErrorCode; import com.google.cloud.spanner.Options.QueryOption; +import com.google.cloud.spanner.ProtobufResultSet; import com.google.cloud.spanner.ResultSet; import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.SpannerExceptionFactory; -import com.google.cloud.spanner.Struct; +import com.google.cloud.spanner.Type; import com.google.cloud.spanner.Type.Code; +import com.google.cloud.spanner.Type.StructField; import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement; import com.google.cloud.spanner.connection.ReadWriteTransaction.RetriableStatement; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; -import com.google.common.hash.Funnel; import com.google.common.hash.HashCode; -import com.google.common.hash.HashFunction; -import com.google.common.hash.Hasher; -import com.google.common.hash.Hashing; -import com.google.common.hash.PrimitiveSink; -import java.math.BigDecimal; -import java.util.Objects; +import com.google.protobuf.Value; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.CharsetEncoder; +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.util.Arrays; import java.util.concurrent.Callable; import java.util.concurrent.atomic.AtomicLong; @@ -71,11 +71,11 @@ class ChecksumResultSet extends ReplaceableForwardingResultSet implements Retria private final ParsedStatement statement; private final AnalyzeMode analyzeMode; private final QueryOption[] options; - private final ChecksumResultSet.ChecksumCalculator checksumCalculator = new ChecksumCalculator(); + private final ChecksumCalculator checksumCalculator = new ChecksumCalculator(); ChecksumResultSet( ReadWriteTransaction transaction, - ResultSet delegate, + ProtobufResultSet delegate, ParsedStatement statement, AnalyzeMode analyzeMode, QueryOption... options) { @@ -91,6 +91,13 @@ class ChecksumResultSet extends ReplaceableForwardingResultSet implements Retria this.options = options; } + @Override + public Value getProtobufValue(int columnIndex) { + // We can safely cast to ProtobufResultSet here, as the constructor of this class only accepts + // instances of ProtobufResultSet. + return ((ProtobufResultSet) getDelegate()).getProtobufValue(columnIndex); + } + /** Simple {@link Callable} for calling {@link ResultSet#next()} */ private final class NextCallable implements Callable { @Override @@ -102,7 +109,7 @@ public Boolean call() { boolean res = ChecksumResultSet.super.next(); // Only update the checksum if there was another row to be consumed. if (res) { - checksumCalculator.calculateNextChecksum(getCurrentRowAsStruct()); + checksumCalculator.calculateNextChecksum(ChecksumResultSet.this); } numberOfNextCalls.incrementAndGet(); return res; @@ -118,8 +125,9 @@ public boolean next() { } @VisibleForTesting - HashCode getChecksum() { - // HashCode is immutable and can be safely returned. + byte[] getChecksum() { + // Getting the checksum from the checksumCalculator will create a clone of the current digest + // and return the checksum from the clone, so it is safe to return this value. return checksumCalculator.getChecksum(); } @@ -132,8 +140,8 @@ HashCode getChecksum() { @Override public void retry(AbortedException aborted) throws AbortedException { // Execute the same query and consume the result set to the same point as the original. - ChecksumResultSet.ChecksumCalculator newChecksumCalculator = new ChecksumCalculator(); - ResultSet resultSet = null; + ChecksumCalculator newChecksumCalculator = new ChecksumCalculator(); + ProtobufResultSet resultSet = null; long counter = 0L; try { transaction @@ -150,7 +158,7 @@ public void retry(AbortedException aborted) throws AbortedException { statement, StatementExecutionStep.RETRY_NEXT_ON_RESULT_SET, transaction); next = resultSet.next(); if (next) { - newChecksumCalculator.calculateNextChecksum(resultSet.getCurrentRowAsStruct()); + newChecksumCalculator.calculateNextChecksum(resultSet); } counter++; } @@ -168,9 +176,9 @@ public void retry(AbortedException aborted) throws AbortedException { throw e; } // Check that we have the same number of rows and the same checksum. - HashCode newChecksum = newChecksumCalculator.getChecksum(); - HashCode currentChecksum = checksumCalculator.getChecksum(); - if (counter == numberOfNextCalls.get() && Objects.equals(newChecksum, currentChecksum)) { + byte[] newChecksum = newChecksumCalculator.getChecksum(); + byte[] currentChecksum = checksumCalculator.getChecksum(); + if (counter == numberOfNextCalls.get() && Arrays.equals(newChecksum, currentChecksum)) { // Checksum is ok, we only need to replace the delegate result set if it's still open. if (isClosed()) { resultSet.close(); @@ -184,222 +192,165 @@ public void retry(AbortedException aborted) throws AbortedException { } } - /** Calculates and keeps the current checksum of a {@link ChecksumResultSet} */ + /** + * Calculates a running checksum for all the data that has been seen sofar in this result set. The + * calculation is performed on the protobuf values that were returned by Cloud Spanner, which + * means that no decoding of the results is needed (or allowed!) before calculating the checksum. + * This is more efficient, both in terms of CPU usage and memory consumption, especially if the + * consumer of the result set does not read all values, or is only reading the underlying protobuf + * values. + */ private static final class ChecksumCalculator { - private static final HashFunction SHA256_FUNCTION = Hashing.sha256(); - private HashCode currentChecksum; + // Use a buffer of max 1Mb to hash string data. This means that strings of up to 1Mb in size + // will be hashed in one go, while strings larger than 1Mb will be chunked into pieces of at + // most 1Mb and then fed into the digest. The digest internally creates a copy of the string + // that is being hashed, so chunking large strings prevents them from being loaded into memory + // twice. + private static final int MAX_BUFFER_SIZE = 1 << 20; - private void calculateNextChecksum(Struct row) { - Hasher hasher = SHA256_FUNCTION.newHasher(); - if (currentChecksum != null) { - hasher.putBytes(currentChecksum.asBytes()); + private boolean firstRow = true; + private final MessageDigest digest; + private ByteBuffer buffer; + private ByteBuffer float64Buffer; + + ChecksumCalculator() { + try { + // This is safe, as all Java implementations are required to have MD5 implemented. + // See https://docs.oracle.com/javase/8/docs/api/java/security/MessageDigest.html + // MD5 requires less CPU power than SHA-256, and still offers a low enough collision + // probability for the use case at hand here. + digest = MessageDigest.getInstance("MD5"); + } catch (Throwable t) { + throw SpannerExceptionFactory.asSpannerException(t); } - hasher.putObject(row, StructFunnel.INSTANCE); - currentChecksum = hasher.hash(); } - private HashCode getChecksum() { - return currentChecksum; + private byte[] getChecksum() { + try { + // This is safe, as the MD5 MessageDigest is known to be cloneable. + MessageDigest clone = (MessageDigest) digest.clone(); + return clone.digest(); + } catch (CloneNotSupportedException e) { + throw SpannerExceptionFactory.asSpannerException(e); + } } - } - /** - * A {@link Funnel} implementation for calculating a {@link HashCode} for each row in a {@link - * ResultSet}. - */ - private enum StructFunnel implements Funnel { - INSTANCE; - private static final String NULL = "null"; - - @Override - public void funnel(Struct row, PrimitiveSink into) { - for (int i = 0; i < row.getColumnCount(); i++) { - if (row.isNull(i)) { - funnelValue(Code.STRING, null, into); + private void calculateNextChecksum(ProtobufResultSet resultSet) { + if (firstRow) { + for (StructField field : resultSet.getType().getStructFields()) { + digest.update(field.getType().toString().getBytes(StandardCharsets.UTF_8)); + } + } + for (int col = 0; col < resultSet.getColumnCount(); col++) { + Type type = resultSet.getColumnType(col); + if (resultSet.canGetProtobufValue(col)) { + Value value = resultSet.getProtobufValue(col); + digest.update((byte) value.getKindCase().getNumber()); + pushValue(type, value); } else { - Code type = row.getColumnType(i).getCode(); - switch (type) { - case ARRAY: - funnelArray(row.getColumnType(i).getArrayElementType().getCode(), row, i, into); - break; - case BOOL: - funnelValue(type, row.getBoolean(i), into); - break; - case BYTES: - case PROTO: - funnelValue(type, row.getBytes(i), into); - break; - case DATE: - funnelValue(type, row.getDate(i), into); - break; - case FLOAT64: - funnelValue(type, row.getDouble(i), into); - break; - case NUMERIC: - funnelValue(type, row.getBigDecimal(i), into); - break; - case PG_NUMERIC: - funnelValue(type, row.getString(i), into); - break; - case INT64: - case ENUM: - funnelValue(type, row.getLong(i), into); - break; - case STRING: - funnelValue(type, row.getString(i), into); - break; - case JSON: - funnelValue(type, row.getJson(i), into); - break; - case PG_JSONB: - funnelValue(type, row.getPgJsonb(i), into); - break; - case TIMESTAMP: - funnelValue(type, row.getTimestamp(i), into); - break; - - case STRUCT: - default: - throw new IllegalArgumentException("unsupported row type"); - } + // This will normally not happen, unless the user explicitly sets the decoding mode to + // DIRECT for a query in a read/write transaction. The default decoding mode in the + // Connection API is set to LAZY_PER_COL. + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.FAILED_PRECONDITION, + "Failed to get the underlying protobuf value for the column " + + resultSet.getMetadata().getRowType().getFields(col).getName() + + ". " + + "Executing queries with DecodeMode#DIRECT is not supported in read/write transactions."); } } + firstRow = false; } - private void funnelArray( - Code arrayElementType, Struct row, int columnIndex, PrimitiveSink into) { - funnelValue(Code.STRING, "BeginArray", into); - switch (arrayElementType) { - case BOOL: - into.putInt(row.getBooleanList(columnIndex).size()); - for (Boolean value : row.getBooleanList(columnIndex)) { - funnelValue(Code.BOOL, value, into); - } + private void pushValue(Type type, Value value) { + // Protobuf Value has a very limited set of possible types of values. All Cloud Spanner types + // are mapped to one of the protobuf values listed here, meaning that no changes are needed to + // this calculation when a new type is added to Cloud Spanner. + switch (value.getKindCase()) { + case NULL_VALUE: + // nothing needed, writing the KindCase is enough. break; - case BYTES: - case PROTO: - into.putInt(row.getBytesList(columnIndex).size()); - for (ByteArray value : row.getBytesList(columnIndex)) { - funnelValue(Code.BYTES, value, into); - } + case BOOL_VALUE: + digest.update(value.getBoolValue() ? (byte) 1 : 0); break; - case DATE: - into.putInt(row.getDateList(columnIndex).size()); - for (Date value : row.getDateList(columnIndex)) { - funnelValue(Code.DATE, value, into); - } + case STRING_VALUE: + putString(value.getStringValue()); break; - case FLOAT64: - into.putInt(row.getDoubleList(columnIndex).size()); - for (Double value : row.getDoubleList(columnIndex)) { - funnelValue(Code.FLOAT64, value, into); + case NUMBER_VALUE: + if (float64Buffer == null) { + // Create an 8-byte buffer that can be re-used for all float64 values in this result + // set. + float64Buffer = ByteBuffer.allocate(Double.BYTES); + } else { + float64Buffer.clear(); } + float64Buffer.putDouble(value.getNumberValue()); + float64Buffer.flip(); + digest.update(float64Buffer); break; - case NUMERIC: - into.putInt(row.getBigDecimalList(columnIndex).size()); - for (BigDecimal value : row.getBigDecimalList(columnIndex)) { - funnelValue(Code.NUMERIC, value, into); + case LIST_VALUE: + if (type.getCode() == Code.ARRAY) { + for (Value item : value.getListValue().getValuesList()) { + digest.update((byte) item.getKindCase().getNumber()); + pushValue(type.getArrayElementType(), item); + } + } else { + // This should not be possible. + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.FAILED_PRECONDITION, + "List values that are not an ARRAY are not supported"); } break; - case PG_NUMERIC: - into.putInt(row.getStringList(columnIndex).size()); - for (String value : row.getStringList(columnIndex)) { - funnelValue(Code.STRING, value, into); + case STRUCT_VALUE: + if (type.getCode() == Code.STRUCT) { + for (int col = 0; col < type.getStructFields().size(); col++) { + String name = type.getStructFields().get(col).getName(); + putString(name); + Value item = value.getStructValue().getFieldsMap().get(name); + digest.update((byte) item.getKindCase().getNumber()); + pushValue(type.getStructFields().get(col).getType(), item); + } + } else { + // This should not be possible. + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.FAILED_PRECONDITION, + "Struct values without a struct type are not supported"); } break; - case INT64: - case ENUM: - into.putInt(row.getLongList(columnIndex).size()); - for (Long value : row.getLongList(columnIndex)) { - funnelValue(Code.INT64, value, into); - } - break; - case STRING: - into.putInt(row.getStringList(columnIndex).size()); - for (String value : row.getStringList(columnIndex)) { - funnelValue(Code.STRING, value, into); - } - break; - case JSON: - into.putInt(row.getJsonList(columnIndex).size()); - for (String value : row.getJsonList(columnIndex)) { - funnelValue(Code.JSON, value, into); - } - break; - case PG_JSONB: - into.putInt(row.getPgJsonbList(columnIndex).size()); - for (String value : row.getPgJsonbList(columnIndex)) { - funnelValue(Code.PG_JSONB, value, into); - } - break; - case TIMESTAMP: - into.putInt(row.getTimestampList(columnIndex).size()); - for (Timestamp value : row.getTimestampList(columnIndex)) { - funnelValue(Code.TIMESTAMP, value, into); - } - break; - - case ARRAY: - case STRUCT: default: - throw new IllegalArgumentException("unsupported array element type"); + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.UNIMPLEMENTED, "Unsupported protobuf value: " + value.getKindCase()); } - funnelValue(Code.STRING, "EndArray", into); } - private void funnelValue(Code type, T value, PrimitiveSink into) { - // Include the type name in case the type of a column has changed. - into.putUnencodedChars(type.name()); - if (value == null) { - if (type == Code.BYTES || type == Code.STRING) { - // Put length -1 to distinguish from the string value 'null'. - into.putInt(-1); - } - into.putUnencodedChars(NULL); + /** Hashes a string value in blocks of max MAX_BUFFER_SIZE. */ + private void putString(String stringValue) { + int length = stringValue.length(); + if (buffer == null || (buffer.capacity() < MAX_BUFFER_SIZE && buffer.capacity() < length)) { + // Create a ByteBuffer with a maximum buffer size. + // This buffer is re-used for all string values in the result set. + buffer = ByteBuffer.allocate(Math.min(MAX_BUFFER_SIZE, length)); } else { - switch (type) { - case BOOL: - into.putBoolean((Boolean) value); - break; - case BYTES: - case PROTO: - ByteArray byteArray = (ByteArray) value; - into.putInt(byteArray.length()); - into.putBytes(byteArray.toByteArray()); - break; - case DATE: - Date date = (Date) value; - into.putInt(date.getYear()).putInt(date.getMonth()).putInt(date.getDayOfMonth()); - break; - case FLOAT64: - into.putDouble((Double) value); - break; - case NUMERIC: - String stringRepresentation = value.toString(); - into.putInt(stringRepresentation.length()); - into.putUnencodedChars(stringRepresentation); - break; - case INT64: - case ENUM: - into.putLong((Long) value); - break; - case PG_NUMERIC: - case STRING: - case JSON: - case PG_JSONB: - String stringValue = (String) value; - into.putInt(stringValue.length()); - into.putUnencodedChars(stringValue); - break; - case TIMESTAMP: - Timestamp timestamp = (Timestamp) value; - into.putLong(timestamp.getSeconds()).putInt(timestamp.getNanos()); - break; - case ARRAY: - case STRUCT: - default: - throw new IllegalArgumentException("invalid type for single value"); - } + buffer.clear(); + } + + // Wrap the string in a CharBuffer. This allows us to read from the string in sections without + // creating a new copy of (a part of) the string. E.g. using something like substring(..) + // would create a copy of that part of the string, using CharBuffer.wrap(..) does not. + CharBuffer source = CharBuffer.wrap(stringValue); + CharsetEncoder encoder = StandardCharsets.UTF_8.newEncoder(); + // source.hasRemaining() returns false when all the characters in the string have been + // processed. + while (source.hasRemaining()) { + // Encode the string into bytes and write them into the byte buffer. + // At most MAX_BUFFER_SIZE bytes will be written. + encoder.encode(source, buffer, false); + // Flip the buffer so we can read from the start. + buffer.flip(); + // Put the bytes from the buffer into the digest. + digest.update(buffer); + // Flip the buffer again, so we can repeat and write to the start of the buffer again. + buffer.flip(); } } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/DirectExecuteResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/DirectExecuteResultSet.java index dff915e2cce..1b15ec50822 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/DirectExecuteResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/DirectExecuteResultSet.java @@ -19,6 +19,7 @@ import com.google.cloud.ByteArray; import com.google.cloud.Date; import com.google.cloud.Timestamp; +import com.google.cloud.spanner.ProtobufResultSet; import com.google.cloud.spanner.ResultSet; import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.Struct; @@ -40,7 +41,7 @@ * to the actual query execution. It also ensures that any invalid query will throw an exception at * execution instead of the first next() call by a client. */ -class DirectExecuteResultSet implements ResultSet { +class DirectExecuteResultSet implements ProtobufResultSet { private static final String MISSING_NEXT_CALL = "Must be preceded by a next() call"; private final ResultSet delegate; private boolean nextCalledByClient = false; @@ -79,6 +80,21 @@ public boolean next() throws SpannerException { return initialNextResult; } + @Override + public boolean canGetProtobufValue(int columnIndex) { + return nextCalledByClient + && delegate instanceof ProtobufResultSet + && ((ProtobufResultSet) delegate).canGetProtobufValue(columnIndex); + } + + @Override + public com.google.protobuf.Value getProtobufValue(int columnIndex) { + Preconditions.checkState(nextCalledByClient, MISSING_NEXT_CALL); + Preconditions.checkState( + delegate instanceof ProtobufResultSet, "The result set does not support protobuf values"); + return ((ProtobufResultSet) delegate).getProtobufValue(columnIndex); + } + @Override public Struct getCurrentRowAsStruct() { Preconditions.checkState(nextCalledByClient, MISSING_NEXT_CALL); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java index e1fb87e4ade..6c4290c3b18 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java @@ -39,6 +39,7 @@ import com.google.cloud.spanner.Options.QueryOption; import com.google.cloud.spanner.Options.TransactionOption; import com.google.cloud.spanner.Options.UpdateOption; +import com.google.cloud.spanner.ProtobufResultSet; import com.google.cloud.spanner.ReadContext; import com.google.cloud.spanner.ResultSet; import com.google.cloud.spanner.SpannerException; @@ -427,7 +428,7 @@ public ApiFuture executeQueryAsync( statement, StatementExecutionStep.EXECUTE_STATEMENT, ReadWriteTransaction.this); - ResultSet delegate = + DirectExecuteResultSet delegate = DirectExecuteResultSet.ofResultSet( internalExecuteQuery(statement, analyzeMode, options)); return createAndAddRetryResultSet( @@ -797,7 +798,7 @@ T runWithRetry(Callable callable) throws SpannerException { * returns a retryable {@link ResultSet}. */ private ResultSet createAndAddRetryResultSet( - ResultSet resultSet, + ProtobufResultSet resultSet, ParsedStatement statement, AnalyzeMode analyzeMode, QueryOption... options) { @@ -1091,7 +1092,7 @@ interface RetriableStatement { /** Creates a {@link ChecksumResultSet} for this {@link ReadWriteTransaction}. */ @VisibleForTesting ChecksumResultSet createChecksumResultSet( - ResultSet delegate, + ProtobufResultSet delegate, ParsedStatement statement, AnalyzeMode analyzeMode, QueryOption... options) { diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSet.java index 7370551a46f..a8de14e5121 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSet.java @@ -20,6 +20,7 @@ import com.google.cloud.Date; import com.google.cloud.Timestamp; import com.google.cloud.spanner.ErrorCode; +import com.google.cloud.spanner.ProtobufResultSet; import com.google.cloud.spanner.ResultSet; import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.SpannerExceptionFactory; @@ -42,7 +43,7 @@ * that is fetched using the new transaction. This is achieved by wrapping the returned result sets * in a {@link ReplaceableForwardingResultSet} that replaces its delegate after a transaction retry. */ -class ReplaceableForwardingResultSet implements ResultSet { +class ReplaceableForwardingResultSet implements ProtobufResultSet { private ResultSet delegate; private boolean closed; @@ -60,6 +61,10 @@ void replaceDelegate(ResultSet delegate) { this.delegate = delegate; } + protected ResultSet getDelegate() { + return this.delegate; + } + private void checkClosed() { if (closed) { throw SpannerExceptionFactory.newSpannerException( @@ -77,6 +82,21 @@ public boolean next() throws SpannerException { return delegate.next(); } + @Override + public boolean canGetProtobufValue(int columnIndex) { + return !closed + && delegate instanceof ProtobufResultSet + && ((ProtobufResultSet) delegate).canGetProtobufValue(columnIndex); + } + + @Override + public com.google.protobuf.Value getProtobufValue(int columnIndex) { + checkClosed(); + Preconditions.checkState( + delegate instanceof ProtobufResultSet, "The result set does not support protobuf values"); + return ((ProtobufResultSet) getDelegate()).getProtobufValue(columnIndex); + } + @Override public Struct getCurrentRowAsStruct() { checkClosed(); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerPool.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerPool.java index 2a5a805c2c7..da8da78d92e 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerPool.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerPool.java @@ -17,6 +17,7 @@ package com.google.cloud.spanner.connection; import com.google.cloud.NoCredentials; +import com.google.cloud.spanner.DecodeMode; import com.google.cloud.spanner.ErrorCode; import com.google.cloud.spanner.SessionPoolOptions; import com.google.cloud.spanner.Spanner; @@ -342,6 +343,9 @@ Spanner createSpanner(SpannerPoolKey key, ConnectionOptions options) { .setClientLibToken(MoreObjects.firstNonNull(key.userAgent, CONNECTION_API_CLIENT_LIB_TOKEN)) .setHost(key.host) .setProjectId(key.projectId) + // Use lazy decoding, so we can use the protobuf values for calculating the checksum that is + // needed for read/write transactions. + .setDecodeMode(DecodeMode.LAZY_PER_COL) .setDatabaseRole(options.getDatabaseRole()) .setCredentials(options.getCredentials()); builder.setSessionPoolOption(key.sessionPoolOptions); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/DecodeModeTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/DecodeModeTest.java new file mode 100644 index 00000000000..6a6125e1dda --- /dev/null +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/DecodeModeTest.java @@ -0,0 +1,128 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner.connection; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.spanner.DecodeMode; +import com.google.cloud.spanner.ErrorCode; +import com.google.cloud.spanner.MockSpannerServiceImpl; +import com.google.cloud.spanner.Options; +import com.google.cloud.spanner.ResultSet; +import com.google.cloud.spanner.SpannerException; +import com.google.cloud.spanner.Statement; +import org.junit.After; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class DecodeModeTest extends AbstractMockServerTest { + + @After + public void clearRequests() { + mockSpanner.clearRequests(); + } + + @Test + public void testAllDecodeModes() { + int numRows = 10; + RandomResultSetGenerator generator = new RandomResultSetGenerator(numRows); + String sql = "select * from random"; + Statement statement = Statement.of(sql); + mockSpanner.putStatementResult( + MockSpannerServiceImpl.StatementResult.query(statement, generator.generate())); + + try (Connection connection = createConnection()) { + for (boolean readonly : new boolean[] {true, false}) { + for (boolean autocommit : new boolean[] {true, false}) { + connection.setReadOnly(readonly); + connection.setAutocommit(autocommit); + + int receivedRows = 0; + // DecodeMode#DIRECT is not supported in read/write transactions, as the protobuf value is + // used for checksum calculation. + try (ResultSet direct = + connection.executeQuery( + statement, + !readonly && !autocommit + ? Options.decodeMode(DecodeMode.LAZY_PER_ROW) + : Options.decodeMode(DecodeMode.DIRECT)); + ResultSet lazyPerRow = + connection.executeQuery(statement, Options.decodeMode(DecodeMode.LAZY_PER_ROW)); + ResultSet lazyPerCol = + connection.executeQuery(statement, Options.decodeMode(DecodeMode.LAZY_PER_COL))) { + while (direct.next() && lazyPerRow.next() && lazyPerCol.next()) { + assertEquals(direct.getColumnCount(), lazyPerRow.getColumnCount()); + assertEquals(direct.getColumnCount(), lazyPerCol.getColumnCount()); + for (int col = 0; col < direct.getColumnCount(); col++) { + // Test getting the entire row as a struct both as the first thing we do, and as the + // last thing we do. This ensures that the method works as expected both when a row + // is lazily decoded by this method, and when it has already been decoded by another + // method. + if (col % 2 == 0) { + assertEquals(direct.getCurrentRowAsStruct(), lazyPerRow.getCurrentRowAsStruct()); + assertEquals(direct.getCurrentRowAsStruct(), lazyPerCol.getCurrentRowAsStruct()); + } + assertEquals(direct.isNull(col), lazyPerRow.isNull(col)); + assertEquals(direct.isNull(col), lazyPerCol.isNull(col)); + assertEquals(direct.getValue(col), lazyPerRow.getValue(col)); + assertEquals(direct.getValue(col), lazyPerCol.getValue(col)); + if (col % 2 == 1) { + assertEquals(direct.getCurrentRowAsStruct(), lazyPerRow.getCurrentRowAsStruct()); + assertEquals(direct.getCurrentRowAsStruct(), lazyPerCol.getCurrentRowAsStruct()); + } + } + receivedRows++; + } + assertEquals(numRows, receivedRows); + } + if (!autocommit) { + connection.commit(); + } + } + } + } + } + + @Test + public void testDecodeModeDirect_failsInReadWriteTransaction() { + int numRows = 1; + RandomResultSetGenerator generator = new RandomResultSetGenerator(numRows); + String sql = "select * from random"; + Statement statement = Statement.of(sql); + mockSpanner.putStatementResult( + MockSpannerServiceImpl.StatementResult.query(statement, generator.generate())); + + try (Connection connection = createConnection()) { + connection.setAutocommit(false); + try (ResultSet resultSet = + connection.executeQuery(statement, Options.decodeMode(DecodeMode.DIRECT))) { + SpannerException exception = assertThrows(SpannerException.class, resultSet::next); + assertEquals(ErrorCode.FAILED_PRECONDITION, exception.getErrorCode()); + assertTrue( + exception.getMessage(), + exception + .getMessage() + .contains( + "Executing queries with DecodeMode#DIRECT is not supported in read/write transactions.")); + } + } + } +} diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/DirectExecuteResultSetTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/DirectExecuteResultSetTest.java index 1e4f96d1568..b14f837ff7b 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/DirectExecuteResultSetTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/DirectExecuteResultSetTest.java @@ -59,6 +59,7 @@ public void testMethodCallBeforeNext() throws IllegalAccessException, IllegalArgumentException, InvocationTargetException { List excludedMethods = Arrays.asList( + "canGetProtobufValue", "getStats", "getMetadata", "next", @@ -79,6 +80,7 @@ public void testMethodCallAfterClose() throws IllegalAccessException, IllegalArgumentException, InvocationTargetException { List excludedMethods = Arrays.asList( + "canGetProtobufValue", "getStats", "getMetadata", "next", @@ -101,6 +103,7 @@ public void testMethodCallAfterNextHasReturnedFalse() throws IllegalAccessException, IllegalArgumentException, InvocationTargetException { List excludedMethods = Arrays.asList( + "canGetProtobufValue", "getStats", "getMetadata", "next", diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/RandomResultSetGenerator.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/RandomResultSetGenerator.java index 3091364e17a..2067d36b5ea 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/RandomResultSetGenerator.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/RandomResultSetGenerator.java @@ -34,6 +34,7 @@ import com.google.spanner.v1.TypeAnnotationCode; import com.google.spanner.v1.TypeCode; import java.math.BigDecimal; +import java.math.RoundingMode; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -239,7 +240,10 @@ private void setRandomValue(Value.Builder builder, Type type) { if (dialect == Dialect.POSTGRESQL && randomNaN()) { builder.setStringValue("NaN"); } else { - builder.setStringValue(BigDecimal.valueOf(random.nextDouble()).toString()); + builder.setStringValue( + BigDecimal.valueOf(random.nextDouble()) + .setScale(9, RoundingMode.HALF_UP) + .toString()); } break; case INT64: diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReadWriteTransactionTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReadWriteTransactionTest.java index 0f083fd1e50..8e643cf6e24 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReadWriteTransactionTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReadWriteTransactionTest.java @@ -37,6 +37,7 @@ import com.google.cloud.spanner.CommitResponse; import com.google.cloud.spanner.DatabaseClient; import com.google.cloud.spanner.ErrorCode; +import com.google.cloud.spanner.ProtobufResultSet; import com.google.cloud.spanner.ReadContext.QueryAnalyzeMode; import com.google.cloud.spanner.ResultSet; import com.google.cloud.spanner.ResultSets; @@ -518,193 +519,197 @@ public void testChecksumResultSet() { .setGenre(Genre.FOLK) .build(); ProtocolMessageEnum protoEnumVal = Genre.ROCK; - ResultSet delegate1 = - ResultSets.forRows( - Type.struct( - StructField.of("ID", Type.int64()), - StructField.of("NAME", Type.string()), - StructField.of("AMOUNT", Type.numeric()), - StructField.of("JSON", Type.json()), - StructField.of( - "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), - StructField.of( - "PROTOENUM", - Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), - Arrays.asList( - Struct.newBuilder() - .set("ID") - .to(1L) - .set("NAME") - .to("TEST 1") - .set("AMOUNT") - .to(BigDecimal.valueOf(550, 2)) - .set("JSON") - .to(Value.json(simpleJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(protoEnumVal) - .build(), - Struct.newBuilder() - .set("ID") - .to(2L) - .set("NAME") - .to("TEST 2") - .set("AMOUNT") - .to(BigDecimal.valueOf(750, 2)) - .set("JSON") - .to(Value.json(arrayJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(Genre.JAZZ) - .build())); + ProtobufResultSet delegate1 = + (ProtobufResultSet) + ResultSets.forRows( + Type.struct( + StructField.of("ID", Type.int64()), + StructField.of("NAME", Type.string()), + StructField.of("AMOUNT", Type.numeric()), + StructField.of("JSON", Type.json()), + StructField.of( + "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), + StructField.of( + "PROTOENUM", + Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), + Arrays.asList( + Struct.newBuilder() + .set("ID") + .to(1L) + .set("NAME") + .to("TEST 1") + .set("AMOUNT") + .to(BigDecimal.valueOf(550, 2)) + .set("JSON") + .to(Value.json(simpleJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(protoEnumVal) + .build(), + Struct.newBuilder() + .set("ID") + .to(2L) + .set("NAME") + .to("TEST 2") + .set("AMOUNT") + .to(BigDecimal.valueOf(750, 2)) + .set("JSON") + .to(Value.json(arrayJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(Genre.JAZZ) + .build())); ChecksumResultSet rs1 = transaction.createChecksumResultSet(delegate1, parsedStatement, AnalyzeMode.NONE); - ResultSet delegate2 = - ResultSets.forRows( - Type.struct( - StructField.of("ID", Type.int64()), - StructField.of("NAME", Type.string()), - StructField.of("AMOUNT", Type.numeric()), - StructField.of("JSON", Type.json()), - StructField.of( - "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), - StructField.of( - "PROTOENUM", - Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), - Arrays.asList( - Struct.newBuilder() - .set("ID") - .to(1L) - .set("NAME") - .to("TEST 1") - .set("AMOUNT") - .to(new BigDecimal("5.50")) - .set("JSON") - .to(Value.json(simpleJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(protoEnumVal) - .build(), - Struct.newBuilder() - .set("ID") - .to(2L) - .set("NAME") - .to("TEST 2") - .set("AMOUNT") - .to(new BigDecimal("7.50")) - .set("JSON") - .to(Value.json(arrayJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(Genre.JAZZ) - .build())); + ProtobufResultSet delegate2 = + (ProtobufResultSet) + ResultSets.forRows( + Type.struct( + StructField.of("ID", Type.int64()), + StructField.of("NAME", Type.string()), + StructField.of("AMOUNT", Type.numeric()), + StructField.of("JSON", Type.json()), + StructField.of( + "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), + StructField.of( + "PROTOENUM", + Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), + Arrays.asList( + Struct.newBuilder() + .set("ID") + .to(1L) + .set("NAME") + .to("TEST 1") + .set("AMOUNT") + .to(new BigDecimal("5.50")) + .set("JSON") + .to(Value.json(simpleJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(protoEnumVal) + .build(), + Struct.newBuilder() + .set("ID") + .to(2L) + .set("NAME") + .to("TEST 2") + .set("AMOUNT") + .to(new BigDecimal("7.50")) + .set("JSON") + .to(Value.json(arrayJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(Genre.JAZZ) + .build())); ChecksumResultSet rs2 = transaction.createChecksumResultSet(delegate2, parsedStatement, AnalyzeMode.NONE); // rs1 and rs2 are equal, rs3 contains the same rows, but in a different order - ResultSet delegate3 = - ResultSets.forRows( - Type.struct( - StructField.of("ID", Type.int64()), - StructField.of("NAME", Type.string()), - StructField.of("AMOUNT", Type.numeric()), - StructField.of("JSON", Type.json()), - StructField.of( - "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), - StructField.of( - "PROTOENUM", - Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), - Arrays.asList( - Struct.newBuilder() - .set("ID") - .to(2L) - .set("NAME") - .to("TEST 2") - .set("AMOUNT") - .to(new BigDecimal("7.50")) - .set("JSON") - .to(Value.json(arrayJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(Genre.JAZZ) - .build(), - Struct.newBuilder() - .set("ID") - .to(1L) - .set("NAME") - .to("TEST 1") - .set("AMOUNT") - .to(new BigDecimal("5.50")) - .set("JSON") - .to(Value.json(simpleJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(protoEnumVal) - .build())); + ProtobufResultSet delegate3 = + (ProtobufResultSet) + ResultSets.forRows( + Type.struct( + StructField.of("ID", Type.int64()), + StructField.of("NAME", Type.string()), + StructField.of("AMOUNT", Type.numeric()), + StructField.of("JSON", Type.json()), + StructField.of( + "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), + StructField.of( + "PROTOENUM", + Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), + Arrays.asList( + Struct.newBuilder() + .set("ID") + .to(2L) + .set("NAME") + .to("TEST 2") + .set("AMOUNT") + .to(new BigDecimal("7.50")) + .set("JSON") + .to(Value.json(arrayJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(Genre.JAZZ) + .build(), + Struct.newBuilder() + .set("ID") + .to(1L) + .set("NAME") + .to("TEST 1") + .set("AMOUNT") + .to(new BigDecimal("5.50")) + .set("JSON") + .to(Value.json(simpleJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(protoEnumVal) + .build())); ChecksumResultSet rs3 = transaction.createChecksumResultSet(delegate3, parsedStatement, AnalyzeMode.NONE); // rs4 contains the same rows as rs1 and rs2, but also an additional row - ResultSet delegate4 = - ResultSets.forRows( - Type.struct( - StructField.of("ID", Type.int64()), - StructField.of("NAME", Type.string()), - StructField.of("AMOUNT", Type.numeric()), - StructField.of("JSON", Type.json()), - StructField.of( - "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), - StructField.of( - "PROTOENUM", - Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), - Arrays.asList( - Struct.newBuilder() - .set("ID") - .to(1L) - .set("NAME") - .to("TEST 1") - .set("AMOUNT") - .to(new BigDecimal("5.50")) - .set("JSON") - .to(Value.json(simpleJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(protoEnumVal) - .build(), - Struct.newBuilder() - .set("ID") - .to(2L) - .set("NAME") - .to("TEST 2") - .set("AMOUNT") - .to(new BigDecimal("7.50")) - .set("JSON") - .to(Value.json(arrayJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(Genre.JAZZ) - .build(), - Struct.newBuilder() - .set("ID") - .to(3L) - .set("NAME") - .to("TEST 3") - .set("AMOUNT") - .to(new BigDecimal("9.99")) - .set("JSON") - .to(Value.json(emptyArrayJson)) - .set("PROTO") - .to(null, SingerInfo.getDescriptor()) - .set("PROTOENUM") - .to(Genre.POP) - .build())); + ProtobufResultSet delegate4 = + (ProtobufResultSet) + ResultSets.forRows( + Type.struct( + StructField.of("ID", Type.int64()), + StructField.of("NAME", Type.string()), + StructField.of("AMOUNT", Type.numeric()), + StructField.of("JSON", Type.json()), + StructField.of( + "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), + StructField.of( + "PROTOENUM", + Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), + Arrays.asList( + Struct.newBuilder() + .set("ID") + .to(1L) + .set("NAME") + .to("TEST 1") + .set("AMOUNT") + .to(new BigDecimal("5.50")) + .set("JSON") + .to(Value.json(simpleJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(protoEnumVal) + .build(), + Struct.newBuilder() + .set("ID") + .to(2L) + .set("NAME") + .to("TEST 2") + .set("AMOUNT") + .to(new BigDecimal("7.50")) + .set("JSON") + .to(Value.json(arrayJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(Genre.JAZZ) + .build(), + Struct.newBuilder() + .set("ID") + .to(3L) + .set("NAME") + .to("TEST 3") + .set("AMOUNT") + .to(new BigDecimal("9.99")) + .set("JSON") + .to(Value.json(emptyArrayJson)) + .set("PROTO") + .to(null, SingerInfo.getDescriptor()) + .set("PROTOENUM") + .to(Genre.POP) + .build())); ChecksumResultSet rs4 = transaction.createChecksumResultSet(delegate4, parsedStatement, AnalyzeMode.NONE); @@ -736,44 +741,46 @@ public void testChecksumResultSetWithArray() { ParsedStatement parsedStatement = mock(ParsedStatement.class); Statement statement = Statement.of("SELECT * FROM FOO"); when(parsedStatement.getStatement()).thenReturn(statement); - ResultSet delegate1 = - ResultSets.forRows( - Type.struct( - StructField.of("ID", Type.int64()), - StructField.of("PRICES", Type.array(Type.int64()))), - Arrays.asList( - Struct.newBuilder() - .set("ID") - .to(1L) - .set("PRICES") - .toInt64Array(new long[] {1L, 2L}) - .build(), - Struct.newBuilder() - .set("ID") - .to(2L) - .set("PRICES") - .toInt64Array(new long[] {3L, 4L}) - .build())); + ProtobufResultSet delegate1 = + (ProtobufResultSet) + ResultSets.forRows( + Type.struct( + StructField.of("ID", Type.int64()), + StructField.of("PRICES", Type.array(Type.int64()))), + Arrays.asList( + Struct.newBuilder() + .set("ID") + .to(1L) + .set("PRICES") + .toInt64Array(new long[] {1L, 2L}) + .build(), + Struct.newBuilder() + .set("ID") + .to(2L) + .set("PRICES") + .toInt64Array(new long[] {3L, 4L}) + .build())); ChecksumResultSet rs1 = transaction.createChecksumResultSet(delegate1, parsedStatement, AnalyzeMode.NONE); - ResultSet delegate2 = - ResultSets.forRows( - Type.struct( - StructField.of("ID", Type.int64()), - StructField.of("PRICES", Type.array(Type.int64()))), - Arrays.asList( - Struct.newBuilder() - .set("ID") - .to(1L) - .set("PRICES") - .toInt64Array(new long[] {1L, 2L}) - .build(), - Struct.newBuilder() - .set("ID") - .to(2L) - .set("PRICES") - .toInt64Array(new long[] {3L, 5L}) - .build())); + ProtobufResultSet delegate2 = + (ProtobufResultSet) + ResultSets.forRows( + Type.struct( + StructField.of("ID", Type.int64()), + StructField.of("PRICES", Type.array(Type.int64()))), + Arrays.asList( + Struct.newBuilder() + .set("ID") + .to(1L) + .set("PRICES") + .toInt64Array(new long[] {1L, 2L}) + .build(), + Struct.newBuilder() + .set("ID") + .to(2L) + .set("PRICES") + .toInt64Array(new long[] {3L, 5L}) + .build())); ChecksumResultSet rs2 = transaction.createChecksumResultSet(delegate2, parsedStatement, AnalyzeMode.NONE); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSetTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSetTest.java index bbb34675147..4617c47bc6b 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSetTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSetTest.java @@ -104,7 +104,14 @@ public void testReplace() { public void testMethodCallBeforeNext() throws IllegalAccessException, IllegalArgumentException, InvocationTargetException { List excludedMethods = - Arrays.asList("getStats", "getMetadata", "next", "close", "equals", "hashCode"); + Arrays.asList( + "canGetProtobufValue", + "getStats", + "getMetadata", + "next", + "close", + "equals", + "hashCode"); ReplaceableForwardingResultSet subject = createSubject(); // Test that all methods throw an IllegalStateException except the excluded methods when called // before a call to ResultSet#next(). @@ -116,6 +123,7 @@ public void testMethodCallAfterClose() throws IllegalAccessException, IllegalArgumentException, InvocationTargetException { List excludedMethods = Arrays.asList( + "canGetProtobufValue", "getStats", "getMetadata", "next", @@ -140,6 +148,7 @@ public void testMethodCallAfterNextHasReturnedFalse() throws IllegalAccessException, IllegalArgumentException, InvocationTargetException { List excludedMethods = Arrays.asList( + "canGetProtobufValue", "getStats", "getMetadata", "next",