From 93846ea4e06e3af3b10a43c8ac4bc1101c97e5d3 Mon Sep 17 00:00:00 2001 From: Yves Langisch Date: Thu, 23 Dec 2021 18:12:55 +0100 Subject: [PATCH] Review and add tests. Signed-off-by: David Kocher --- .../net/schmizz/sshj/sftp/RemoteFile.java | 16 ++-- .../hierynomus/sshj/sftp/RemoteFileTest.java | 90 +++++++++++++++++++ 2 files changed, 101 insertions(+), 5 deletions(-) diff --git a/src/main/java/net/schmizz/sshj/sftp/RemoteFile.java b/src/main/java/net/schmizz/sshj/sftp/RemoteFile.java index 7a0febcd8..a5558030e 100644 --- a/src/main/java/net/schmizz/sshj/sftp/RemoteFile.java +++ b/src/main/java/net/schmizz/sshj/sftp/RemoteFile.java @@ -224,7 +224,7 @@ public class ReadAheadRemoteFileInputStream private final byte[] b = new byte[1]; private final int maxUnconfirmedReads; - private final long maxOffset; + private final long readAheadLimit; private final Queue> unconfirmedReads = new LinkedList>(); private final Queue unconfirmedReadOffsets = new LinkedList(); @@ -240,15 +240,15 @@ public ReadAheadRemoteFileInputStream(int maxUnconfirmedReads) { * * @param maxUnconfirmedReads Maximum number of unconfirmed requests to send * @param fileOffset Initial offset in file to read from - * @param maxLength Maximum length to read + * @param readAheadLimit Read ahead is disabled after this limit has been reached */ - public ReadAheadRemoteFileInputStream(int maxUnconfirmedReads, long fileOffset, long maxLength) { + public ReadAheadRemoteFileInputStream(int maxUnconfirmedReads, long fileOffset, long readAheadLimit) { assert 0 <= maxUnconfirmedReads; assert 0 <= fileOffset; this.maxUnconfirmedReads = maxUnconfirmedReads; this.requestOffset = this.responseOffset = fileOffset; - this.maxOffset = maxLength > 0 ? fileOffset + maxLength : Long.MAX_VALUE; + this.readAheadLimit = readAheadLimit > 0 ? fileOffset + readAheadLimit : Long.MAX_VALUE; } private ByteArrayInputStream pending = new ByteArrayInputStream(new byte[0]); @@ -299,10 +299,16 @@ public int read(byte[] into, int off, int len) throws IOException { while (unconfirmedReads.size() <= maxUnconfirmedReads) { // Send read requests as long as there is no EOF and we have not reached the maximum parallelism int reqLen = Math.max(1024, len); // don't be shy! + if (readAheadLimit > requestOffset) { + long remaining = readAheadLimit - requestOffset; + if (reqLen > remaining) { + reqLen = (int) remaining; + } + } unconfirmedReads.add(RemoteFile.this.asyncRead(requestOffset, reqLen)); unconfirmedReadOffsets.add(requestOffset); requestOffset += reqLen; - if (requestOffset >= maxOffset) { + if (requestOffset >= readAheadLimit) { break; } } diff --git a/src/test/java/com/hierynomus/sshj/sftp/RemoteFileTest.java b/src/test/java/com/hierynomus/sshj/sftp/RemoteFileTest.java index 3436af42f..949a917c0 100644 --- a/src/test/java/com/hierynomus/sshj/sftp/RemoteFileTest.java +++ b/src/test/java/com/hierynomus/sshj/sftp/RemoteFileTest.java @@ -20,6 +20,7 @@ import net.schmizz.sshj.sftp.OpenMode; import net.schmizz.sshj.sftp.RemoteFile; import net.schmizz.sshj.sftp.SFTPEngine; +import net.schmizz.sshj.sftp.SFTPException; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; @@ -32,6 +33,7 @@ import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.fail; public class RemoteFileTest { @Rule @@ -84,4 +86,92 @@ public void shouldNotGoOutOfBoundsInReadAheadInputStream() throws IOException { assertThat("The written and received data should match", data, equalTo(test2)); } + + @Test + public void shouldNotReadAheadAfterLimitInputStream() throws IOException { + SSHClient ssh = fixture.setupConnectedDefaultClient(); + ssh.authPassword("test", "test"); + SFTPEngine sftp = new SFTPEngine(ssh).init(); + + RemoteFile rf; + File file = temp.newFile("SftpReadAheadLimitTest.bin"); + rf = sftp.open(file.getPath(), EnumSet.of(OpenMode.WRITE, OpenMode.CREAT)); + byte[] data = new byte[8192]; + new Random(53).nextBytes(data); + data[3072] = 1; + rf.write(0, data, 0, data.length); + rf.close(); + + assertThat("The file should exist", file.exists()); + + rf = sftp.open(file.getPath()); + InputStream rs = rf.new ReadAheadRemoteFileInputStream(16 /*maxUnconfirmedReads*/,0, 3072); + + byte[] test = new byte[4097]; + int n = 0; + + while (n < 2048) { + n += rs.read(test, n, 2048 - n); + } + + rf.close(); + + while (n < 3072) { + n += rs.read(test, n, 3072 - n); + } + + assertThat("buffer overrun", test[3072] == 0); + + try { + rs.read(test, n, test.length - n); + fail("Content must not be buffered"); + } catch (SFTPException e){ + // expected + } + } + + @Test + public void limitedReadAheadInputStream() throws IOException { + SSHClient ssh = fixture.setupConnectedDefaultClient(); + ssh.authPassword("test", "test"); + SFTPEngine sftp = new SFTPEngine(ssh).init(); + + RemoteFile rf; + File file = temp.newFile("SftpReadAheadLimitedTest.bin"); + rf = sftp.open(file.getPath(), EnumSet.of(OpenMode.WRITE, OpenMode.CREAT)); + byte[] data = new byte[8192]; + new Random(53).nextBytes(data); + data[3072] = 1; + rf.write(0, data, 0, data.length); + rf.close(); + + assertThat("The file should exist", file.exists()); + + rf = sftp.open(file.getPath()); + InputStream rs = rf.new ReadAheadRemoteFileInputStream(16 /*maxUnconfirmedReads*/,0, 3072); + + byte[] test = new byte[4097]; + int n = 0; + + while (n < 2048) { + n += rs.read(test, n, 2048 - n); + } + + while (n < 3072) { + n += rs.read(test, n, 3072 - n); + } + + assertThat("buffer overrun", test[3072] == 0); + + n += rs.read(test, n, test.length - n); // --> ArrayIndexOutOfBoundsException + + byte[] test2 = new byte[data.length]; + System.arraycopy(test, 0, test2, 0, test.length); + + while (n < data.length) { + n += rs.read(test2, n, data.length - n); + } + + assertThat("The written and received data should match", data, equalTo(test2)); + } }