Skip to content

Commit

Permalink
Updates SourceBufferingInputStream to conform to InputStream contract (
Browse files Browse the repository at this point in the history
  • Loading branch information
rharter authored and JakeWharton committed Apr 19, 2018
1 parent 8da44f3 commit d959c4e
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ final class SourceBufferingInputStream extends InputStream {
private final BufferedSource source;
private final Buffer buffer;
private long position;
private long mark = -1;
private long markPosition = -1;
private long markLimit = -1;

SourceBufferingInputStream(BufferedSource source) {
this.source = source;
Expand All @@ -42,34 +43,48 @@ final class SourceBufferingInputStream extends InputStream {
private final Buffer temp = new Buffer();
private int copyTo(byte[] sink, int offset, int byteCount) {
// TODO replace this with https://github.com/square/okio/issues/362
buffer.copyTo(temp, offset, byteCount);
// `copyTo` treats offset as the read position, `read` treats offset as the write offset.
buffer.copyTo(temp, position, byteCount);
return temp.read(sink, offset, byteCount);
}

@Override public int read() throws IOException {
source.require(position + 1);
if (!source.request(position + 1)) {
return -1;
}
byte value = buffer.getByte(position++);
if (position > mark) {
mark = -1;
if (position > markLimit) {
markPosition = -1;
}
return value;
}

@Override public int read(@NonNull byte[] b, int off, int len) throws IOException {
source.require(position + len);
int copied = /*buffer.*/copyTo(b, off, len);
if (off < 0 || len < 0 || len > b.length - off) {
throw new IndexOutOfBoundsException();
} else if (len == 0) {
return 0;
}

int count = len;
if (!source.request(position + count)) {
count = available();
}
if (count == 0) return -1;

int copied = /*buffer.*/copyTo(b, off, count);
position += copied;
if (position > mark) {
mark = -1;
if (position > markLimit) {
markPosition = -1;
}
return copied;
}

@Override public long skip(long n) throws IOException {
source.require(position + n);
position += n;
if (position > mark) {
mark = -1;
if (position > markLimit) {
markPosition = -1;
}
return n;
}
Expand All @@ -79,15 +94,17 @@ private int copyTo(byte[] sink, int offset, int byteCount) {
}

@Override public void mark(int readlimit) {
mark = position + readlimit;
markPosition = position;
markLimit = position + readlimit;
}

@Override public void reset() throws IOException {
if (mark == -1) {
if (markPosition == -1) {
throw new IOException("No mark or mark expired");
}
position = mark;
mark = -1;
position = markPosition;
markPosition = -1;
markLimit = -1;
}

@Override public int available() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
import okio.Timeout;
import org.junit.Test;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;

public final class SourceBufferingInputStreamTest {
@Test public void replay() throws IOException {
Expand Down Expand Up @@ -45,6 +47,59 @@ public final class SourceBufferingInputStreamTest {
assertEquals(13, stream.available());
}

@Test public void read() throws IOException {
Buffer data = new Buffer().writeUtf8("Hello, world!");
BufferedSource source = Okio.buffer((Source) data);
InputStream stream = new SourceBufferingInputStream(source);
byte[] bytes = new byte[5];
int len = stream.read(bytes);
assertEquals(5, len);
assertArrayEquals(new byte[] {'H', 'e', 'l', 'l', 'o'}, bytes);

len = stream.read(bytes);
assertEquals(5, len);
assertArrayEquals(new byte[] {',', ' ', 'w', 'o', 'r'}, bytes);

len = stream.read(bytes);
assertEquals(3, len);
// Last two bytes are out of range so untouched from previous run.
assertArrayEquals(new byte[] {'l', 'd', '!', 'o', 'r'}, bytes);

len = stream.read(bytes);
assertEquals(-1, len);
}

@Test public void markReset() throws IOException {
Buffer data = new Buffer().writeUtf8("Hello, world!");
BufferedSource source = Okio.buffer((Source) data);
InputStream stream = new SourceBufferingInputStream(source);
stream.mark(2);

byte[] bytes = new byte[4];
int len = stream.read(bytes, 0, 2);
assertEquals(2, len);
assertArrayEquals(new byte[] {'H', 'e', 0, 0}, bytes);

len = stream.read(bytes);
assertEquals(len, 4);
assertArrayEquals(new byte[] {'l', 'l', 'o', ','}, bytes);

try {
stream.reset();
fail("expected IOException on reset");
} catch (IOException expected) {}

stream.mark(2);
len = stream.read(bytes, 0, 2);
assertEquals(2, len);
assertArrayEquals(new byte[] {' ', 'w', 'o', ','}, bytes);

stream.reset();
len = stream.read(bytes);
assertEquals(4, len);
assertArrayEquals(new byte[] {' ', 'w', 'o', 'r'}, bytes);
}

/** Prevents a consumer from reading large chunks and exercises edge cases. */
private static final class OneByteSource implements Source {
private final Source upstream;
Expand Down

0 comments on commit d959c4e

Please sign in to comment.