diff --git a/google-cloud-storage/src/main/java/com/google/cloud/storage/GapicUnbufferedDirectWritableByteChannel.java b/google-cloud-storage/src/main/java/com/google/cloud/storage/GapicUnbufferedDirectWritableByteChannel.java index 861e867db1..5e67440a7e 100644 --- a/google-cloud-storage/src/main/java/com/google/cloud/storage/GapicUnbufferedDirectWritableByteChannel.java +++ b/google-cloud-storage/src/main/java/com/google/cloud/storage/GapicUnbufferedDirectWritableByteChannel.java @@ -20,12 +20,14 @@ import com.google.api.core.SettableApiFuture; import com.google.api.gax.grpc.GrpcCallContext; +import com.google.api.gax.rpc.ApiException; import com.google.api.gax.rpc.ApiStreamObserver; import com.google.api.gax.rpc.ClientStreamingCallable; import com.google.cloud.storage.ChunkSegmenter.ChunkSegment; import com.google.cloud.storage.Crc32cValue.Crc32cLengthKnown; import com.google.cloud.storage.UnbufferedWritableByteChannelSession.UnbufferedWritableByteChannel; import com.google.cloud.storage.WriteCtx.SimpleWriteObjectRequestBuilderFactory; +import com.google.common.collect.ImmutableList; import com.google.protobuf.ByteString; import com.google.storage.v2.ChecksummedData; import com.google.storage.v2.ObjectChecksums; @@ -34,11 +36,7 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.ClosedChannelException; -import java.util.ArrayList; -import java.util.List; import java.util.concurrent.ExecutionException; -import java.util.function.Consumer; -import java.util.function.LongConsumer; import org.checkerframework.checker.nullness.qual.NonNull; final class GapicUnbufferedDirectWritableByteChannel implements UnbufferedWritableByteChannel { @@ -55,13 +53,14 @@ final class GapicUnbufferedDirectWritableByteChannel implements UnbufferedWritab private boolean open = true; private boolean first = true; private boolean finished = false; + private volatile WriteObjectRequest lastWrittenRequest; GapicUnbufferedDirectWritableByteChannel( SettableApiFuture resultFuture, ChunkSegmenter chunkSegmenter, ClientStreamingCallable write, - SimpleWriteObjectRequestBuilderFactory requestFactory) { - String bucketName = requestFactory.bucketName(); + WriteCtx writeCtx) { + String bucketName = writeCtx.getRequestFactory().bucketName(); this.resultFuture = resultFuture; this.chunkSegmenter = chunkSegmenter; @@ -69,20 +68,63 @@ final class GapicUnbufferedDirectWritableByteChannel implements UnbufferedWritab contextWithBucketName(bucketName, GrpcCallContext.createDefault()); this.write = write.withDefaultCallContext(internalContext); - this.writeCtx = new WriteCtx<>(requestFactory); - this.responseObserver = new Observer(writeCtx.getConfirmedBytes()::set, resultFuture::set); + this.writeCtx = writeCtx; + this.responseObserver = new Observer(internalContext); } @Override public long write(ByteBuffer[] srcs, int srcsOffset, int srcsLength) throws IOException { - return internalWrite(srcs, srcsOffset, srcsLength, false); - } + if (!open) { + throw new ClosedChannelException(); + } - @Override - public long writeAndClose(ByteBuffer[] srcs, int srcsOffset, int srcsLength) throws IOException { - long write = internalWrite(srcs, srcsOffset, srcsLength, true); - close(); - return write; + ChunkSegment[] data = chunkSegmenter.segmentBuffers(srcs, srcsOffset, srcsLength); + if (data.length == 0) { + return 0; + } + + try { + ApiStreamObserver openedStream = openedStream(); + int bytesConsumed = 0; + for (ChunkSegment datum : data) { + Crc32cLengthKnown crc32c = datum.getCrc32c(); + ByteString b = datum.getB(); + int contentSize = b.size(); + long offset = writeCtx.getTotalSentBytes().getAndAdd(contentSize); + Crc32cLengthKnown cumulative = + writeCtx + .getCumulativeCrc32c() + .accumulateAndGet(crc32c, chunkSegmenter.getHasher()::nullSafeConcat); + ChecksummedData.Builder checksummedData = ChecksummedData.newBuilder().setContent(b); + if (crc32c != null) { + checksummedData.setCrc32C(crc32c.getValue()); + } + WriteObjectRequest.Builder builder = writeCtx.newRequestBuilder(); + if (!first) { + builder.clearWriteObjectSpec(); + builder.clearObjectChecksums(); + } + builder.setWriteOffset(offset).setChecksummedData(checksummedData.build()); + if (!datum.isOnlyFullBlocks()) { + builder.setFinishWrite(true); + if (cumulative != null) { + builder.setObjectChecksums( + ObjectChecksums.newBuilder().setCrc32C(cumulative.getValue()).build()); + } + finished = true; + } + + WriteObjectRequest build = builder.build(); + first = false; + bytesConsumed += contentSize; + lastWrittenRequest = build; + openedStream.onNext(build); + } + return bytesConsumed; + } catch (RuntimeException e) { + resultFuture.setException(e); + throw e; + } } @Override @@ -95,6 +137,7 @@ public void close() throws IOException { ApiStreamObserver openedStream = openedStream(); if (!finished) { WriteObjectRequest message = finishMessage(); + lastWrittenRequest = message; try { openedStream.onNext(message); openedStream.onCompleted(); @@ -115,79 +158,22 @@ public void close() throws IOException { responseObserver.await(); } - private long internalWrite(ByteBuffer[] srcs, int srcsOffset, int srcsLength, boolean finalize) - throws ClosedChannelException { - if (!open) { - throw new ClosedChannelException(); - } - - ChunkSegment[] data = chunkSegmenter.segmentBuffers(srcs, srcsOffset, srcsLength); - - List messages = new ArrayList<>(); - - ApiStreamObserver openedStream = openedStream(); - int bytesConsumed = 0; - for (ChunkSegment datum : data) { - Crc32cLengthKnown crc32c = datum.getCrc32c(); - ByteString b = datum.getB(); - int contentSize = b.size(); - long offset = writeCtx.getTotalSentBytes().getAndAdd(contentSize); - Crc32cLengthKnown cumulative = - writeCtx - .getCumulativeCrc32c() - .accumulateAndGet(crc32c, chunkSegmenter.getHasher()::nullSafeConcat); - ChecksummedData.Builder checksummedData = ChecksummedData.newBuilder().setContent(b); - if (crc32c != null) { - checksummedData.setCrc32C(crc32c.getValue()); - } - WriteObjectRequest.Builder builder = - writeCtx - .newRequestBuilder() - .setWriteOffset(offset) - .setChecksummedData(checksummedData.build()); - if (!datum.isOnlyFullBlocks()) { - builder.setFinishWrite(true); - if (cumulative != null) { - builder.setObjectChecksums( - ObjectChecksums.newBuilder().setCrc32C(cumulative.getValue()).build()); - } - finished = true; - } - - WriteObjectRequest build = possiblyPairDownRequest(builder, first).build(); - first = false; - messages.add(build); - bytesConsumed += contentSize; - } - if (finalize && !finished) { - messages.add(finishMessage()); - finished = true; - } - - try { - for (WriteObjectRequest message : messages) { - openedStream.onNext(message); - } - } catch (RuntimeException e) { - resultFuture.setException(e); - throw e; - } - - return bytesConsumed; - } - @NonNull private WriteObjectRequest finishMessage() { long offset = writeCtx.getTotalSentBytes().get(); Crc32cLengthKnown crc32cValue = writeCtx.getCumulativeCrc32c().get(); - WriteObjectRequest.Builder b = - writeCtx.newRequestBuilder().setFinishWrite(true).setWriteOffset(offset); + WriteObjectRequest.Builder b = writeCtx.newRequestBuilder(); + if (!first) { + b.clearWriteObjectSpec(); + b.clearObjectChecksums(); + first = false; + } + b.setFinishWrite(true).setWriteOffset(offset); if (crc32cValue != null) { b.setObjectChecksums(ObjectChecksums.newBuilder().setCrc32C(crc32cValue.getValue()).build()); } - WriteObjectRequest message = possiblyPairDownRequest(b, first).build(); - return message; + return b.build(); } private ApiStreamObserver openedStream() { @@ -201,48 +187,20 @@ private ApiStreamObserver openedStream() { return stream; } - /** - * Several fields of a WriteObjectRequest are only allowed on the "first" message sent to gcs, - * this utility method centralizes the logic necessary to clear those fields for use by subsequent - * messages. - */ - private static WriteObjectRequest.Builder possiblyPairDownRequest( - WriteObjectRequest.Builder b, boolean firstMessageOfStream) { - if (firstMessageOfStream && b.getWriteOffset() == 0) { - return b; - } - if (b.getWriteOffset() > 0) { - b.clearWriteObjectSpec(); - } - - if (b.getWriteOffset() > 0 && !b.getFinishWrite()) { - b.clearObjectChecksums(); - } - return b; - } - - static class Observer implements ApiStreamObserver { + class Observer implements ApiStreamObserver { - private final LongConsumer sizeCallback; - private final Consumer completeCallback; + private final GrpcCallContext context; private final SettableApiFuture invocationHandle; private volatile WriteObjectResponse last; - Observer(LongConsumer sizeCallback, Consumer completeCallback) { - this.sizeCallback = sizeCallback; - this.completeCallback = completeCallback; + Observer(GrpcCallContext context) { + this.context = context; this.invocationHandle = SettableApiFuture.create(); } @Override public void onNext(WriteObjectResponse value) { - // incremental update - if (value.hasPersistedSize()) { - sizeCallback.accept(value.getPersistedSize()); - } else if (value.hasResource()) { - sizeCallback.accept(value.getResource().getSize()); - } last = value; } @@ -257,15 +215,58 @@ public void onNext(WriteObjectResponse value) { */ @Override public void onError(Throwable t) { - invocationHandle.setException(t); + if (t instanceof ApiException) { + // use StorageExceptions logic to translate from ApiException to our status codes ensuring + // things fall in line with our retry handlers. + // This is suboptimal, as it will initialize a second exception, however this is the + // unusual case, and it should not cause a significant overhead given its rarity. + StorageException tmp = StorageException.asStorageException((ApiException) t); + StorageException storageException = + ResumableSessionFailureScenario.toStorageException( + tmp.getCode(), tmp.getMessage(), tmp.getReason(), getRequests(), null, context, t); + invocationHandle.setException(storageException); + } else { + invocationHandle.setException(t); + } } @Override public void onCompleted() { - if (last != null && last.hasResource()) { - completeCallback.accept(last); + try { + if (last == null) { + throw new StorageException( + 0, "onComplete without preceding onNext, unable to determine success."); + } else if (last.hasResource()) { + long totalSentBytes = writeCtx.getTotalSentBytes().get(); + long finalSize = last.getResource().getSize(); + if (totalSentBytes == finalSize) { + writeCtx.getConfirmedBytes().set(finalSize); + resultFuture.set(last); + } else if (finalSize < totalSentBytes) { + throw ResumableSessionFailureScenario.SCENARIO_4_1.toStorageException( + getRequests(), last, context, null); + } else { + throw ResumableSessionFailureScenario.SCENARIO_4_2.toStorageException( + getRequests(), last, context, null); + } + } else { + throw ResumableSessionFailureScenario.SCENARIO_0.toStorageException( + getRequests(), last, context, null); + } + } catch (Throwable se) { + open = false; + invocationHandle.setException(se); + } finally { + invocationHandle.set(null); + } + } + + private @NonNull ImmutableList<@NonNull WriteObjectRequest> getRequests() { + if (lastWrittenRequest == null) { + return ImmutableList.of(); + } else { + return ImmutableList.of(lastWrittenRequest); } - invocationHandle.set(null); } void await() { diff --git a/google-cloud-storage/src/main/java/com/google/cloud/storage/GapicWritableByteChannelSessionBuilder.java b/google-cloud-storage/src/main/java/com/google/cloud/storage/GapicWritableByteChannelSessionBuilder.java index 8854053322..f4a8afdb9d 100644 --- a/google-cloud-storage/src/main/java/com/google/cloud/storage/GapicWritableByteChannelSessionBuilder.java +++ b/google-cloud-storage/src/main/java/com/google/cloud/storage/GapicWritableByteChannelSessionBuilder.java @@ -185,7 +185,7 @@ UnbufferedWritableByteChannelSession build() { resultFuture, getChunkSegmenter(), write, - WriteObjectRequestBuilderFactory.simple(start))) + new WriteCtx<>(WriteObjectRequestBuilderFactory.simple(start)))) .andThen(StorageByteChannels.writable()::createSynchronized)); } } @@ -213,7 +213,7 @@ BufferedWritableByteChannelSession build() { resultFuture, getChunkSegmenter(), write, - WriteObjectRequestBuilderFactory.simple(start))) + new WriteCtx<>(WriteObjectRequestBuilderFactory.simple(start)))) .andThen(c -> new DefaultBufferedWritableByteChannel(bufferHandle, c)) .andThen(StorageByteChannels.writable()::createSynchronized)); } diff --git a/google-cloud-storage/src/main/java/com/google/cloud/storage/ResumableSessionFailureScenario.java b/google-cloud-storage/src/main/java/com/google/cloud/storage/ResumableSessionFailureScenario.java index 294d481cdc..8c06d70dda 100644 --- a/google-cloud-storage/src/main/java/com/google/cloud/storage/ResumableSessionFailureScenario.java +++ b/google-cloud-storage/src/main/java/com/google/cloud/storage/ResumableSessionFailureScenario.java @@ -64,11 +64,11 @@ enum ResumableSessionFailureScenario { SCENARIO_4_1( BaseServiceException.UNKNOWN_CODE, "dataLoss", - "Finalized resumable session, but object size less than expected."), + "Finalized upload, but object size less than expected."), SCENARIO_4_2( BaseServiceException.UNKNOWN_CODE, "dataLoss", - "Finalized resumable session, but object size greater than expected."), + "Finalized upload, but object size greater than expected."), SCENARIO_5( BaseServiceException.UNKNOWN_CODE, "dataLoss", diff --git a/google-cloud-storage/src/test/java/com/google/cloud/storage/ITGapicUnbufferedDirectWritableByteChannelTest.java b/google-cloud-storage/src/test/java/com/google/cloud/storage/ITGapicUnbufferedDirectWritableByteChannelTest.java new file mode 100644 index 0000000000..4a70e9a2e3 --- /dev/null +++ b/google-cloud-storage/src/test/java/com/google/cloud/storage/ITGapicUnbufferedDirectWritableByteChannelTest.java @@ -0,0 +1,161 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.storage; + +import static com.google.cloud.storage.ByteSizeConstants._256KiB; +import static com.google.cloud.storage.ByteSizeConstants._512KiB; +import static com.google.cloud.storage.ByteSizeConstants._768KiB; +import static com.google.cloud.storage.TestUtils.assertAll; +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.api.core.SettableApiFuture; +import com.google.cloud.storage.ITGapicUnbufferedWritableByteChannelTest.DirectWriteService; +import com.google.cloud.storage.WriteCtx.SimpleWriteObjectRequestBuilderFactory; +import com.google.cloud.storage.WriteCtx.WriteObjectRequestBuilderFactory; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.storage.v2.Object; +import com.google.storage.v2.StorageClient; +import com.google.storage.v2.WriteObjectRequest; +import com.google.storage.v2.WriteObjectResponse; +import java.util.List; +import java.util.concurrent.TimeUnit; +import org.junit.Test; + +public final class ITGapicUnbufferedDirectWritableByteChannelTest { + + private static final ChunkSegmenter CHUNK_SEGMENTER = + new ChunkSegmenter(Hasher.noop(), ByteStringStrategy.copy(), _256KiB, _256KiB); + + /** Attempting to finalize, ack equals expected */ + @Test + public void ack_eq() throws Exception { + WriteObjectRequest req1 = + WriteObjectRequest.newBuilder().setWriteOffset(_256KiB).setFinishWrite(true).build(); + WriteObjectResponse resp1 = + WriteObjectResponse.newBuilder() + .setResource(Object.newBuilder().setName("name").setSize(_256KiB).build()) + .build(); + + ImmutableMap, WriteObjectResponse> map = + ImmutableMap.of(ImmutableList.of(req1), resp1); + DirectWriteService service1 = new DirectWriteService(map); + + try (FakeServer fakeServer = FakeServer.of(service1); + GrpcStorageImpl storage = + (GrpcStorageImpl) fakeServer.getGrpcStorageOptions().getService()) { + StorageClient storageClient = storage.storageClient; + + SettableApiFuture done = SettableApiFuture.create(); + WriteCtx writeCtx = + new WriteCtx<>(WriteObjectRequestBuilderFactory.simple(req1)); + writeCtx.getTotalSentBytes().set(_256KiB); + writeCtx.getConfirmedBytes().set(0); + + GapicUnbufferedDirectWritableByteChannel channel = + new GapicUnbufferedDirectWritableByteChannel( + done, CHUNK_SEGMENTER, storageClient.writeObjectCallable(), writeCtx); + + channel.close(); + + WriteObjectResponse writeObjectResponse = done.get(2, TimeUnit.SECONDS); + assertAll( + () -> assertThat(writeObjectResponse).isEqualTo(resp1), + () -> assertThat(writeCtx.getConfirmedBytes().get()).isEqualTo(_256KiB), + () -> assertThat(channel.isOpen()).isFalse()); + } + } + + /** Attempting to finalize, ack < expected */ + @Test + public void ack_lt() throws Exception { + WriteObjectRequest req1 = + WriteObjectRequest.newBuilder().setWriteOffset(_512KiB).setFinishWrite(true).build(); + WriteObjectResponse resp1 = + WriteObjectResponse.newBuilder() + .setResource(Object.newBuilder().setName("name").setSize(_256KiB).build()) + .build(); + + ImmutableMap, WriteObjectResponse> map = + ImmutableMap.of(ImmutableList.of(req1), resp1); + DirectWriteService service1 = new DirectWriteService(map); + + try (FakeServer fakeServer = FakeServer.of(service1); + GrpcStorageImpl storage = + (GrpcStorageImpl) fakeServer.getGrpcStorageOptions().getService()) { + StorageClient storageClient = storage.storageClient; + + SettableApiFuture done = SettableApiFuture.create(); + WriteCtx writeCtx = + new WriteCtx<>(WriteObjectRequestBuilderFactory.simple(req1)); + writeCtx.getTotalSentBytes().set(_512KiB); + writeCtx.getConfirmedBytes().set(0); + + //noinspection resource + GapicUnbufferedDirectWritableByteChannel channel = + new GapicUnbufferedDirectWritableByteChannel( + done, CHUNK_SEGMENTER, storageClient.writeObjectCallable(), writeCtx); + + StorageException se = assertThrows(StorageException.class, channel::close); + assertAll( + () -> assertThat(se.getCode()).isEqualTo(0), + () -> assertThat(se.getReason()).isEqualTo("dataLoss"), + () -> assertThat(writeCtx.getConfirmedBytes().get()).isEqualTo(0), + () -> assertThat(channel.isOpen()).isFalse()); + } + } + + /** Attempting to finalize, ack > expected */ + @Test + public void ack_gt() throws Exception { + WriteObjectRequest req1 = + WriteObjectRequest.newBuilder().setWriteOffset(_512KiB).setFinishWrite(true).build(); + WriteObjectResponse resp1 = + WriteObjectResponse.newBuilder() + .setResource(Object.newBuilder().setName("name").setSize(_768KiB).build()) + .build(); + + ImmutableMap, WriteObjectResponse> map = + ImmutableMap.of(ImmutableList.of(req1), resp1); + DirectWriteService service1 = new DirectWriteService(map); + + try (FakeServer fakeServer = FakeServer.of(service1); + GrpcStorageImpl storage = + (GrpcStorageImpl) fakeServer.getGrpcStorageOptions().getService()) { + StorageClient storageClient = storage.storageClient; + + SettableApiFuture done = SettableApiFuture.create(); + WriteCtx writeCtx = + new WriteCtx<>(WriteObjectRequestBuilderFactory.simple(req1)); + writeCtx.getTotalSentBytes().set(_512KiB); + writeCtx.getConfirmedBytes().set(0); + + //noinspection resource + GapicUnbufferedDirectWritableByteChannel channel = + new GapicUnbufferedDirectWritableByteChannel( + done, CHUNK_SEGMENTER, storageClient.writeObjectCallable(), writeCtx); + + StorageException se = assertThrows(StorageException.class, channel::close); + assertAll( + () -> assertThat(se.getCode()).isEqualTo(0), + () -> assertThat(se.getReason()).isEqualTo("dataLoss"), + () -> assertThat(writeCtx.getConfirmedBytes().get()).isEqualTo(0), + () -> assertThat(channel.isOpen()).isFalse()); + } + } +} diff --git a/google-cloud-storage/src/test/java/com/google/cloud/storage/ITGapicUnbufferedWritableByteChannelTest.java b/google-cloud-storage/src/test/java/com/google/cloud/storage/ITGapicUnbufferedWritableByteChannelTest.java index 239b9bf3e6..326cbb1566 100644 --- a/google-cloud-storage/src/test/java/com/google/cloud/storage/ITGapicUnbufferedWritableByteChannelTest.java +++ b/google-cloud-storage/src/test/java/com/google/cloud/storage/ITGapicUnbufferedWritableByteChannelTest.java @@ -153,11 +153,14 @@ public void directUpload() throws IOException, InterruptedException, ExecutionEx new DirectWriteService( ImmutableMap.of(ImmutableList.of(req1, req2, req3, req4, req5), resp)); try (FakeServer fake = FakeServer.of(service); - StorageClient sc = StorageClient.create(fake.storageSettings())) { + StorageClient sc = + PackagePrivateMethodWorkarounds.maybeGetStorageClient( + fake.getGrpcStorageOptions().getService())) { + assertThat(sc).isNotNull(); SettableApiFuture result = SettableApiFuture.create(); try (GapicUnbufferedDirectWritableByteChannel c = new GapicUnbufferedDirectWritableByteChannel( - result, segmenter, sc.writeObjectCallable(), reqFactory)) { + result, segmenter, sc.writeObjectCallable(), new WriteCtx<>(reqFactory))) { c.write(ByteBuffer.wrap(bytes)); } assertThat(result.get()).isEqualTo(resp);