fix: update grpc single-shot uploads to validate ack'd object size #2567
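This PR makes the single-shot (direct) gRPC write path keep the last WriteObjectRequest it sent and, when the stream completes, compare the total bytes sent against the object size the server acknowledges in the WriteObjectResponse; a mismatch now fails the write with a "dataLoss" StorageException instead of completing silently. Below is a minimal sketch of that check with simplified names (AckSizeValidator, validate, and the IllegalStateException stand-in are illustrative only; the real logic lives in Observer#onCompleted in the diff and uses ResumableSessionFailureScenario):

// Sketch only: mirrors the size validation this PR adds in
// GapicUnbufferedDirectWritableByteChannel.Observer#onCompleted.
// Class and method names here are illustrative, not part of the library.
final class AckSizeValidator {

  /**
   * @param totalSentBytes bytes this channel has written to the request stream
   * @param ackedSize size reported by the server in WriteObjectResponse.resource.size
   */
  static void validate(long totalSentBytes, long ackedSize) {
    if (ackedSize == totalSentBytes) {
      return; // sizes agree: confirm the bytes and complete the result future
    }
    // The real code maps these branches to ResumableSessionFailureScenario.SCENARIO_4_1
    // and SCENARIO_4_2, which build "dataLoss" StorageExceptions that also carry the
    // last request sent and the final response for debugging.
    String message =
        ackedSize < totalSentBytes
            ? "Finalized upload, but object size less than expected."
            : "Finalized upload, but object size greater than expected.";
    throw new IllegalStateException(message); // stand-in for the StorageException used in the PR
  }
}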

Merged (1 commit) on Jun 11, 2024
@@ -20,12 +20,14 @@

import com.google.api.core.SettableApiFuture;
import com.google.api.gax.grpc.GrpcCallContext;
import com.google.api.gax.rpc.ApiException;
import com.google.api.gax.rpc.ApiStreamObserver;
import com.google.api.gax.rpc.ClientStreamingCallable;
import com.google.cloud.storage.ChunkSegmenter.ChunkSegment;
import com.google.cloud.storage.Crc32cValue.Crc32cLengthKnown;
import com.google.cloud.storage.UnbufferedWritableByteChannelSession.UnbufferedWritableByteChannel;
import com.google.cloud.storage.WriteCtx.SimpleWriteObjectRequestBuilderFactory;
import com.google.common.collect.ImmutableList;
import com.google.protobuf.ByteString;
import com.google.storage.v2.ChecksummedData;
import com.google.storage.v2.ObjectChecksums;
@@ -34,11 +36,7 @@
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.function.Consumer;
import java.util.function.LongConsumer;
import org.checkerframework.checker.nullness.qual.NonNull;

final class GapicUnbufferedDirectWritableByteChannel implements UnbufferedWritableByteChannel {
@@ -55,34 +53,78 @@ final class GapicUnbufferedDirectWritableByteChannel implements UnbufferedWritab
private boolean open = true;
private boolean first = true;
private boolean finished = false;
private volatile WriteObjectRequest lastWrittenRequest;

GapicUnbufferedDirectWritableByteChannel(
SettableApiFuture<WriteObjectResponse> resultFuture,
ChunkSegmenter chunkSegmenter,
ClientStreamingCallable<WriteObjectRequest, WriteObjectResponse> write,
SimpleWriteObjectRequestBuilderFactory requestFactory) {
String bucketName = requestFactory.bucketName();
WriteCtx<SimpleWriteObjectRequestBuilderFactory> writeCtx) {
String bucketName = writeCtx.getRequestFactory().bucketName();
this.resultFuture = resultFuture;
this.chunkSegmenter = chunkSegmenter;

GrpcCallContext internalContext =
contextWithBucketName(bucketName, GrpcCallContext.createDefault());
this.write = write.withDefaultCallContext(internalContext);

this.writeCtx = new WriteCtx<>(requestFactory);
this.responseObserver = new Observer(writeCtx.getConfirmedBytes()::set, resultFuture::set);
this.writeCtx = writeCtx;
this.responseObserver = new Observer(internalContext);
}

@Override
public long write(ByteBuffer[] srcs, int srcsOffset, int srcsLength) throws IOException {
return internalWrite(srcs, srcsOffset, srcsLength, false);
}
if (!open) {
throw new ClosedChannelException();
}

@Override
public long writeAndClose(ByteBuffer[] srcs, int srcsOffset, int srcsLength) throws IOException {
long write = internalWrite(srcs, srcsOffset, srcsLength, true);
close();
return write;
ChunkSegment[] data = chunkSegmenter.segmentBuffers(srcs, srcsOffset, srcsLength);
if (data.length == 0) {
return 0;
}

try {
ApiStreamObserver<WriteObjectRequest> openedStream = openedStream();
int bytesConsumed = 0;
for (ChunkSegment datum : data) {
Crc32cLengthKnown crc32c = datum.getCrc32c();
ByteString b = datum.getB();
int contentSize = b.size();
long offset = writeCtx.getTotalSentBytes().getAndAdd(contentSize);
Crc32cLengthKnown cumulative =
writeCtx
.getCumulativeCrc32c()
.accumulateAndGet(crc32c, chunkSegmenter.getHasher()::nullSafeConcat);
ChecksummedData.Builder checksummedData = ChecksummedData.newBuilder().setContent(b);
if (crc32c != null) {
checksummedData.setCrc32C(crc32c.getValue());
}
WriteObjectRequest.Builder builder = writeCtx.newRequestBuilder();
if (!first) {
builder.clearWriteObjectSpec();
builder.clearObjectChecksums();
}
builder.setWriteOffset(offset).setChecksummedData(checksummedData.build());
if (!datum.isOnlyFullBlocks()) {
builder.setFinishWrite(true);
if (cumulative != null) {
builder.setObjectChecksums(
ObjectChecksums.newBuilder().setCrc32C(cumulative.getValue()).build());
}
finished = true;
}

WriteObjectRequest build = builder.build();
first = false;
bytesConsumed += contentSize;
lastWrittenRequest = build;
openedStream.onNext(build);
}
return bytesConsumed;
} catch (RuntimeException e) {
resultFuture.setException(e);
throw e;
}
}

@Override
@@ -95,6 +137,7 @@ public void close() throws IOException {
ApiStreamObserver<WriteObjectRequest> openedStream = openedStream();
if (!finished) {
WriteObjectRequest message = finishMessage();
lastWrittenRequest = message;
try {
openedStream.onNext(message);
openedStream.onCompleted();
@@ -115,79 +158,22 @@ public void close() throws IOException {
responseObserver.await();
}

private long internalWrite(ByteBuffer[] srcs, int srcsOffset, int srcsLength, boolean finalize)
throws ClosedChannelException {
if (!open) {
throw new ClosedChannelException();
}

ChunkSegment[] data = chunkSegmenter.segmentBuffers(srcs, srcsOffset, srcsLength);

List<WriteObjectRequest> messages = new ArrayList<>();

ApiStreamObserver<WriteObjectRequest> openedStream = openedStream();
int bytesConsumed = 0;
for (ChunkSegment datum : data) {
Crc32cLengthKnown crc32c = datum.getCrc32c();
ByteString b = datum.getB();
int contentSize = b.size();
long offset = writeCtx.getTotalSentBytes().getAndAdd(contentSize);
Crc32cLengthKnown cumulative =
writeCtx
.getCumulativeCrc32c()
.accumulateAndGet(crc32c, chunkSegmenter.getHasher()::nullSafeConcat);
ChecksummedData.Builder checksummedData = ChecksummedData.newBuilder().setContent(b);
if (crc32c != null) {
checksummedData.setCrc32C(crc32c.getValue());
}
WriteObjectRequest.Builder builder =
writeCtx
.newRequestBuilder()
.setWriteOffset(offset)
.setChecksummedData(checksummedData.build());
if (!datum.isOnlyFullBlocks()) {
builder.setFinishWrite(true);
if (cumulative != null) {
builder.setObjectChecksums(
ObjectChecksums.newBuilder().setCrc32C(cumulative.getValue()).build());
}
finished = true;
}

WriteObjectRequest build = possiblyPairDownRequest(builder, first).build();
first = false;
messages.add(build);
bytesConsumed += contentSize;
}
if (finalize && !finished) {
messages.add(finishMessage());
finished = true;
}

try {
for (WriteObjectRequest message : messages) {
openedStream.onNext(message);
}
} catch (RuntimeException e) {
resultFuture.setException(e);
throw e;
}

return bytesConsumed;
}

@NonNull
private WriteObjectRequest finishMessage() {
long offset = writeCtx.getTotalSentBytes().get();
Crc32cLengthKnown crc32cValue = writeCtx.getCumulativeCrc32c().get();

WriteObjectRequest.Builder b =
writeCtx.newRequestBuilder().setFinishWrite(true).setWriteOffset(offset);
WriteObjectRequest.Builder b = writeCtx.newRequestBuilder();
if (!first) {
b.clearWriteObjectSpec();
b.clearObjectChecksums();
first = false;
}
b.setFinishWrite(true).setWriteOffset(offset);
if (crc32cValue != null) {
b.setObjectChecksums(ObjectChecksums.newBuilder().setCrc32C(crc32cValue.getValue()).build());
}
WriteObjectRequest message = possiblyPairDownRequest(b, first).build();
return message;
return b.build();
}

private ApiStreamObserver<WriteObjectRequest> openedStream() {
@@ -201,48 +187,20 @@ private ApiStreamObserver<WriteObjectRequest> openedStream() {
return stream;
}

/**
* Several fields of a WriteObjectRequest are only allowed on the "first" message sent to gcs,
* this utility method centralizes the logic necessary to clear those fields for use by subsequent
* messages.
*/
private static WriteObjectRequest.Builder possiblyPairDownRequest(
WriteObjectRequest.Builder b, boolean firstMessageOfStream) {
if (firstMessageOfStream && b.getWriteOffset() == 0) {
return b;
}
if (b.getWriteOffset() > 0) {
b.clearWriteObjectSpec();
}

if (b.getWriteOffset() > 0 && !b.getFinishWrite()) {
b.clearObjectChecksums();
}
return b;
}

static class Observer implements ApiStreamObserver<WriteObjectResponse> {
class Observer implements ApiStreamObserver<WriteObjectResponse> {

private final LongConsumer sizeCallback;
private final Consumer<WriteObjectResponse> completeCallback;
private final GrpcCallContext context;

private final SettableApiFuture<Void> invocationHandle;
private volatile WriteObjectResponse last;

Observer(LongConsumer sizeCallback, Consumer<WriteObjectResponse> completeCallback) {
this.sizeCallback = sizeCallback;
this.completeCallback = completeCallback;
Observer(GrpcCallContext context) {
this.context = context;
this.invocationHandle = SettableApiFuture.create();
}

@Override
public void onNext(WriteObjectResponse value) {
// incremental update
if (value.hasPersistedSize()) {
sizeCallback.accept(value.getPersistedSize());
} else if (value.hasResource()) {
sizeCallback.accept(value.getResource().getSize());
}
last = value;
}

@@ -257,15 +215,58 @@ public void onNext(WriteObjectResponse value) {
*/
@Override
public void onError(Throwable t) {
invocationHandle.setException(t);
if (t instanceof ApiException) {
// use StorageExceptions logic to translate from ApiException to our status codes ensuring
// things fall in line with our retry handlers.
// This is suboptimal, as it will initialize a second exception, however this is the
// unusual case, and it should not cause a significant overhead given its rarity.
StorageException tmp = StorageException.asStorageException((ApiException) t);
StorageException storageException =
ResumableSessionFailureScenario.toStorageException(
tmp.getCode(), tmp.getMessage(), tmp.getReason(), getRequests(), null, context, t);
invocationHandle.setException(storageException);
} else {
invocationHandle.setException(t);
}
}

@Override
public void onCompleted() {
if (last != null && last.hasResource()) {
completeCallback.accept(last);
try {
if (last == null) {
throw new StorageException(
0, "onComplete without preceding onNext, unable to determine success.");
} else if (last.hasResource()) {
long totalSentBytes = writeCtx.getTotalSentBytes().get();
long finalSize = last.getResource().getSize();
if (totalSentBytes == finalSize) {
writeCtx.getConfirmedBytes().set(finalSize);
resultFuture.set(last);
} else if (finalSize < totalSentBytes) {
throw ResumableSessionFailureScenario.SCENARIO_4_1.toStorageException(
getRequests(), last, context, null);
} else {
throw ResumableSessionFailureScenario.SCENARIO_4_2.toStorageException(
getRequests(), last, context, null);
}
} else {
throw ResumableSessionFailureScenario.SCENARIO_0.toStorageException(
getRequests(), last, context, null);
}
} catch (Throwable se) {
open = false;
invocationHandle.setException(se);
} finally {
invocationHandle.set(null);
}
}

private @NonNull ImmutableList<@NonNull WriteObjectRequest> getRequests() {
if (lastWrittenRequest == null) {
return ImmutableList.of();
} else {
return ImmutableList.of(lastWrittenRequest);
}
invocationHandle.set(null);
}

void await() {
@@ -185,7 +185,7 @@ UnbufferedWritableByteChannelSession<WriteObjectResponse> build() {
resultFuture,
getChunkSegmenter(),
write,
WriteObjectRequestBuilderFactory.simple(start)))
new WriteCtx<>(WriteObjectRequestBuilderFactory.simple(start))))
.andThen(StorageByteChannels.writable()::createSynchronized));
}
}
@@ -213,7 +213,7 @@ BufferedWritableByteChannelSession<WriteObjectResponse> build() {
resultFuture,
getChunkSegmenter(),
write,
WriteObjectRequestBuilderFactory.simple(start)))
new WriteCtx<>(WriteObjectRequestBuilderFactory.simple(start))))
.andThen(c -> new DefaultBufferedWritableByteChannel(bufferHandle, c))
.andThen(StorageByteChannels.writable()::createSynchronized));
}
@@ -64,11 +64,11 @@ enum ResumableSessionFailureScenario {
SCENARIO_4_1(
BaseServiceException.UNKNOWN_CODE,
"dataLoss",
"Finalized resumable session, but object size less than expected."),
"Finalized upload, but object size less than expected."),
SCENARIO_4_2(
BaseServiceException.UNKNOWN_CODE,
"dataLoss",
"Finalized resumable session, but object size greater than expected."),
"Finalized upload, but object size greater than expected."),
SCENARIO_5(
BaseServiceException.UNKNOWN_CODE,
"dataLoss",
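For callers, the practical effect of the two reworded scenario messages above is that a truncated or over-sized single-shot upload now surfaces as an error instead of appearing to succeed. A hedged usage sketch follows, assuming a gRPC-transport client and a simple byte[] upload (bucket and object names are placeholders, and the exact public method that routes through the changed channel depends on request options):

import com.google.cloud.storage.BlobInfo;
import com.google.cloud.storage.Storage;
import com.google.cloud.storage.StorageException;
import com.google.cloud.storage.StorageOptions;
import java.nio.charset.StandardCharsets;

final class SingleShotUploadExample {
  public static void main(String[] args) {
    // Assumes the gRPC transport so the channel changed in this PR is in play.
    Storage storage = StorageOptions.grpc().build().getService();
    byte[] data = "hello, world".getBytes(StandardCharsets.UTF_8);
    try {
      storage.create(BlobInfo.newBuilder("my-bucket", "my-object").build(), data);
    } catch (StorageException e) {
      // With this change, a finalize-time size mismatch is reported with reason "dataLoss"
      // (SCENARIO_4_1 / SCENARIO_4_2) rather than the upload completing as if it succeeded.
      if ("dataLoss".equals(e.getReason())) {
        // treat the upload as failed; the server-side object does not match what was sent
      }
      throw e;
    }
  }
}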