fix: update GapicUnbufferedChunkedResumableWritableByteChannel to be tolerant of non-quantum writes (#2537)

Update GapicUnbufferedChunkedResumableWritableByteChannel to only accept bytes at the 256KiB boundary when `#write(ByteBuffer[], int, int)` is called.

Calls to `#writeAndClose(ByteBuffer[], int, int)` will consume all bytes.
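For illustration, a minimal sketch of the resulting caller-visible contract (a sketch, not part of the diff; it assumes access to the package-private `UnbufferedWritableByteChannel` that the `json()` and `grpc()` tests below obtain from `session.open()`):

```java
import java.io.IOException;
import java.nio.ByteBuffer;

class NonQuantumWriteSketch {
  // Assumes `channel` behaves like the channel returned by session.open() in
  // the tests below; 256KiB is the resumable-upload block size.
  static void demo(UnbufferedWritableByteChannel channel) throws IOException {
    int quantum = 256 * 1024;
    ByteBuffer b = ByteBuffer.wrap(new byte[quantum + 13]);

    int written1 = channel.write(b);         // consumes only whole 256KiB blocks
    // written1 == quantum, b.remaining() == 13

    int written2 = channel.write(b);         // fewer than 256KiB left: nothing consumed
    // written2 == 0, b.remaining() == 13

    int written3 = channel.writeAndClose(b); // finalizing write drains the tail
    // written3 == 13, b.remaining() == 0
  }
}
```

The `json()` and `grpc()` integration tests in this commit assert exactly this sequence.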
BenWhitehead authored May 21, 2024
1 parent 79d721d commit 1701fde
Showing 5 changed files with 293 additions and 15 deletions.
@@ -18,7 +18,10 @@

import com.google.cloud.storage.Crc32cValue.Crc32cLengthKnown;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.math.IntMath;
import com.google.protobuf.ByteString;
import java.math.RoundingMode;
import java.nio.ByteBuffer;
import java.util.ArrayDeque;
import java.util.Deque;
@@ -40,6 +43,13 @@ final class ChunkSegmenter {

@VisibleForTesting
ChunkSegmenter(Hasher hasher, ByteStringStrategy bss, int maxSegmentSize, int blockSize) {
int mod = maxSegmentSize % blockSize;
Preconditions.checkArgument(
mod == 0,
"maxSegmentSize % blockSize == 0 (%s % %s == %s)",
maxSegmentSize,
blockSize,
mod);
this.hasher = hasher;
this.bss = bss;
this.maxSegmentSize = maxSegmentSize;
@@ -79,32 +89,92 @@ ChunkSegment[] segmentBuffers(ByteBuffer[] bbs) {
}

ChunkSegment[] segmentBuffers(ByteBuffer[] bbs, int offset, int length) {
return segmentBuffers(bbs, offset, length, true);
}

ChunkSegment[] segmentBuffers(
ByteBuffer[] bbs, int offset, int length, boolean allowUnalignedBlocks) {
// turn this into a single branch, rather than multiple that would need to be checked for
// each element of the iteration
if (allowUnalignedBlocks) {
return segmentWithUnaligned(bbs, offset, length);
} else {
return segmentWithoutUnaligned(bbs, offset, length);
}
}

private ChunkSegment[] segmentWithUnaligned(ByteBuffer[] bbs, int offset, int length) {
Deque<ChunkSegment> data = new ArrayDeque<>();

for (int i = offset; i < length; i++) {
ByteBuffer buffer = bbs[i];
int remaining;
while ((remaining = buffer.remaining()) > 0) {
consumeBytes(data, remaining, buffer);
}
}

return data.toArray(new ChunkSegment[0]);
}

private ChunkSegment[] segmentWithoutUnaligned(ByteBuffer[] bbs, int offset, int length) {
Deque<ChunkSegment> data = new ArrayDeque<>();

final long totalRemaining = Buffers.totalRemaining(bbs, offset, length);
long consumedSoFar = 0;

int currentBlockPending = blockSize;

for (int i = offset; i < length; i++) {
ByteBuffer buffer = bbs[i];
int remaining;
while ((remaining = buffer.remaining()) > 0) {
long overallRemaining = totalRemaining - consumedSoFar;
if (overallRemaining < blockSize && currentBlockPending == blockSize) {
break;
}

int numBytesConsumable;
if (remaining >= blockSize) {
int blockCount = IntMath.divide(remaining, blockSize, RoundingMode.DOWN);
numBytesConsumable = blockCount * blockSize;
} else if (currentBlockPending < blockSize) {
numBytesConsumable = currentBlockPending;
currentBlockPending = blockSize;
} else {
numBytesConsumable = remaining;
currentBlockPending = currentBlockPending - remaining;
}
if (numBytesConsumable <= 0) {
continue;
}

consumedSoFar += consumeBytes(data, numBytesConsumable, buffer);
}
}

return data.toArray(new ChunkSegment[0]);
}

private long consumeBytes(Deque<ChunkSegment> data, int numBytesConsumable, ByteBuffer buffer) {
// either no chunk or most recent chunk is full, start a new one
ChunkSegment peekLast = data.peekLast();
if (peekLast == null || peekLast.b.size() == maxSegmentSize) {
int limit = Math.min(numBytesConsumable, maxSegmentSize);
ChunkSegment datum = newSegment(buffer, limit);
data.addLast(datum);
return limit;
} else {
ChunkSegment chunkSoFar = data.pollLast();
//noinspection ConstantConditions -- covered by peekLast check above
int limit = Math.min(numBytesConsumable, maxSegmentSize - chunkSoFar.b.size());
ChunkSegment datum = newSegment(buffer, limit);
ChunkSegment plus = chunkSoFar.concat(datum);
data.addLast(plus);
return limit;
}
}

private ChunkSegment newSegment(ByteBuffer buffer, int limit) {
final ByteBuffer slice = buffer.slice();
slice.limit(limit);
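To make the alignment arithmetic above concrete, a small sketch (not part of the commit) of how `segmentWithoutUnaligned` sizes each consumption, using the same Guava `IntMath` call the new code imports:

```java
import com.google.common.math.IntMath;
import java.math.RoundingMode;

class BlockAlignmentSketch {
  public static void main(String[] args) {
    int blockSize = 7;
    int remaining = 19; // bytes left in the current buffer

    // Same computation as segmentWithoutUnaligned: round down to whole blocks.
    int blockCount = IntMath.divide(remaining, blockSize, RoundingMode.DOWN); // 2
    int numBytesConsumable = blockCount * blockSize;                          // 14

    System.out.println(numBytesConsumable); // 14; the 5-byte tail stays put
    // The tail is consumed later only to complete a block that continues in a
    // following buffer; a tail at the very end of all buffers is left unconsumed.
  }
}
```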
@@ -121,7 +121,10 @@ private long internalWrite(ByteBuffer[] srcs, int srcsOffset, int srcsLength, boolean finalize)

long begin = writeCtx.getConfirmedBytes().get();
RewindableContent content = RewindableContent.of(srcs, srcsOffset, srcsLength);
ChunkSegment[] data = chunkSegmenter.segmentBuffers(srcs, srcsOffset, srcsLength, finalize);
if (data.length == 0) {
return 0;
}

List<WriteObjectRequest> messages = new ArrayList<>();

@@ -16,18 +16,24 @@

package com.google.cloud.storage;

import static com.google.cloud.storage.TestUtils.assertAll;
import static com.google.common.truth.Truth.assertThat;

import com.google.cloud.storage.ChunkSegmenter.ChunkSegment;
import com.google.cloud.storage.Crc32cValue.Crc32cLengthKnown;
import com.google.common.collect.ImmutableList;
import com.google.common.hash.HashCode;
import com.google.common.hash.Hashing;
import com.google.protobuf.ByteString;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import net.jqwik.api.Arbitraries;
import net.jqwik.api.Arbitrary;
import net.jqwik.api.Example;
import net.jqwik.api.ForAll;
import net.jqwik.api.Property;
import net.jqwik.api.Provide;
@@ -58,6 +64,113 @@ void chunkIt(@ForAll("TestData") TestData td) {
assertThat(reduce).isAnyOf(Optional.empty(), Optional.of(Crc32cValue.of(td.allCrc32c.asInt())));
}

/**
 * <pre>
 * Given 64 bytes, maxSegmentSize: 10, blockSize: 5
 * 0                                                              64
 * |---------------------------------------------------------------|
 * Produce 6 10-byte segments
 * |---------|---------|---------|---------|---------|---------|
 * </pre>
 */
@Example
void allowUnalignedBlocks_false_1() {
ChunkSegmenter segmenter =
new ChunkSegmenter(Hasher.noop(), ByteStringStrategy.noCopy(), 10, 5);

byte[] bytes = DataGenerator.base64Characters().genBytes(64);
List<ByteString> expected =
ImmutableList.of(
ByteString.copyFrom(bytes, 0, 10),
ByteString.copyFrom(bytes, 10, 10),
ByteString.copyFrom(bytes, 20, 10),
ByteString.copyFrom(bytes, 30, 10),
ByteString.copyFrom(bytes, 40, 10),
ByteString.copyFrom(bytes, 50, 10));

ByteBuffer buf = ByteBuffer.wrap(bytes);

ChunkSegment[] segments = segmenter.segmentBuffers(new ByteBuffer[] {buf}, 0, 1, false);
assertThat(buf.remaining()).isEqualTo(4);
List<ByteString> actual =
Arrays.stream(segments).map(ChunkSegment::getB).collect(Collectors.toList());
assertThat(actual).isEqualTo(expected);
}

/**
 * <pre>
 * Given 64 bytes, maxSegmentSize: 14, blockSize: 7
 * 0                                                              64
 * |---------------------------------------------------------------|
 * Produce 4 14-byte segments, and one 7-byte segment
 * |-------------|-------------|-------------|-------------|------|
 * </pre>
 */
@Example
void allowUnalignedBlocks_false_2() throws Exception {
ChunkSegmenter segmenter =
new ChunkSegmenter(Hasher.noop(), ByteStringStrategy.noCopy(), 14, 7);

byte[] bytes = DataGenerator.base64Characters().genBytes(64);
List<ByteString> expected =
ImmutableList.of(
ByteString.copyFrom(bytes, 0, 14),
ByteString.copyFrom(bytes, 14, 14),
ByteString.copyFrom(bytes, 28, 14),
ByteString.copyFrom(bytes, 42, 14),
ByteString.copyFrom(bytes, 56, 7));

ByteBuffer buf = ByteBuffer.wrap(bytes);

ChunkSegment[] segments = segmenter.segmentBuffers(new ByteBuffer[] {buf}, 0, 1, false);
List<ByteString> actual =
Arrays.stream(segments).map(ChunkSegment::getB).collect(Collectors.toList());
assertAll(
() -> assertThat(buf.remaining()).isEqualTo(1),
() -> assertThat(actual).isEqualTo(expected));
}

/**
 * <pre>
 * Given 60 bytes in one buffer and 4 bytes in a second buffer, maxSegmentSize: 14, blockSize: 7
 * 0                                                          60   4
 * |-----------------------------------------------------------|---|
 * Produce 4 14-byte segments, and one 7-byte segment
 * |-------------|-------------|-------------|-------------|------|
 * </pre>
 */
@Example
void allowUnalignedBlocks_false_3() throws Exception {
ChunkSegmenter segmenter =
new ChunkSegmenter(Hasher.noop(), ByteStringStrategy.noCopy(), 14, 7);

byte[] bytes = DataGenerator.base64Characters().genBytes(64);
List<ByteString> expected =
ImmutableList.of(
ByteString.copyFrom(bytes, 0, 14),
ByteString.copyFrom(bytes, 14, 14),
ByteString.copyFrom(bytes, 28, 14),
ByteString.copyFrom(bytes, 42, 14),
ByteString.copyFrom(bytes, 56, 7));

ByteBuffer buf1 = ByteBuffer.wrap(bytes, 0, 60);
ByteBuffer buf2 = ByteBuffer.wrap(bytes, 60, 4);

ChunkSegment[] segments = segmenter.segmentBuffers(new ByteBuffer[] {buf1, buf2}, 0, 2, false);
List<ByteString> actual =
Arrays.stream(segments).map(ChunkSegment::getB).collect(Collectors.toList());
assertAll(
() -> assertThat(buf1.remaining()).isEqualTo(0),
() -> assertThat(buf2.remaining()).isEqualTo(1),
() -> assertThat(actual).isEqualTo(expected));
}

@Provide("TestData")
static Arbitrary<TestData> arbitraryTestData() {
return Arbitraries.lazyOf(
@@ -18,7 +18,9 @@

import static com.google.common.truth.Truth.assertThat;

import com.google.api.core.ApiFuture;
import com.google.api.core.ApiFutures;
import com.google.api.gax.grpc.GrpcCallContext;
import com.google.api.services.storage.model.StorageObject;
import com.google.cloud.storage.ITUnbufferedResumableUploadTest.ObjectSizes;
import com.google.cloud.storage.TransportCompatibility.Transport;
@@ -36,6 +38,11 @@
import com.google.cloud.storage.it.runner.registry.Generator;
import com.google.cloud.storage.spi.v1.StorageRpc;
import com.google.common.collect.ImmutableList;
import com.google.storage.v2.Object;
import com.google.storage.v2.StorageClient;
import com.google.storage.v2.WriteObjectRequest;
import com.google.storage.v2.WriteObjectResponse;
import com.google.storage.v2.WriteObjectSpec;
import java.io.IOException;
import java.math.BigInteger;
import java.nio.ByteBuffer;
@@ -103,8 +110,13 @@ public void json()
ByteBuffer b = DataGenerator.base64Characters().genByteBuffer(size);

UnbufferedWritableByteChannel open = session.open();
int written1 = open.write(b);
assertThat(written1).isEqualTo(objectSize);
assertThat(b.remaining()).isEqualTo(additional);

// no bytes should be consumed if less than 256KiB
int written2 = open.write(b);
assertThat(written2).isEqualTo(0);
assertThat(b.remaining()).isEqualTo(additional);

int writtenAndClose = open.writeAndClose(b);
@@ -114,4 +126,70 @@
StorageObject storageObject = session.getResult().get(2, TimeUnit.SECONDS);
assertThat(storageObject.getSize()).isEqualTo(BigInteger.valueOf(size));
}

@Test
@Exclude(transports = Transport.HTTP)
public void grpc() throws Exception {
BlobInfo blobInfo = BlobInfo.newBuilder(bucket, generator.randomObjectName()).build();
Opts<ObjectTargetOpt> opts = Opts.empty();
BlobInfo.Builder builder = blobInfo.toBuilder().setMd5(null).setCrc32c(null);
BlobInfo updated = opts.blobInfoMapper().apply(builder).build();

Object object = Conversions.grpc().blobInfo().encode(updated);
Object.Builder objectBuilder =
object
.toBuilder()
// required if the data is changing
.clearChecksums()
// trimmed to shave payload size
.clearGeneration()
.clearMetageneration()
.clearSize()
.clearCreateTime()
.clearUpdateTime();
WriteObjectSpec.Builder specBuilder = WriteObjectSpec.newBuilder().setResource(objectBuilder);

WriteObjectRequest.Builder requestBuilder =
WriteObjectRequest.newBuilder().setWriteObjectSpec(specBuilder);

WriteObjectRequest request = opts.writeObjectRequest().apply(requestBuilder).build();

GrpcCallContext merge = Retrying.newCallContext();
StorageClient storageClient = PackagePrivateMethodWorkarounds.maybeGetStorageClient(storage);
assertThat(storageClient).isNotNull();
ApiFuture<ResumableWrite> start =
ResumableMedia.gapic()
.write()
.resumableWrite(
storageClient.startResumableWriteCallable().withDefaultCallContext(merge), request);

UnbufferedWritableByteChannelSession<WriteObjectResponse> session =
ResumableMedia.gapic()
.write()
.byteChannel(storageClient.writeObjectCallable())
.resumable()
.unbuffered()
.setStartAsync(start)
.build();

int additional = 13;
long size = objectSize + additional;
ByteBuffer b = DataGenerator.base64Characters().genByteBuffer(size);

UnbufferedWritableByteChannel open = session.open();
int written1 = open.write(b);
assertThat(written1).isEqualTo(objectSize);
assertThat(b.remaining()).isEqualTo(additional);

// no bytes should be consumed if less than 256KiB
int written2 = open.write(b);
assertThat(written2).isEqualTo(0);
assertThat(b.remaining()).isEqualTo(additional);

int writtenAndClose = open.writeAndClose(b);
assertThat(writtenAndClose).isEqualTo(additional);
open.close();
WriteObjectResponse resp = session.getResult().get(2, TimeUnit.SECONDS);
assertThat(resp.getResource().getSize()).isEqualTo(size);
}
}