diff --git a/src/it/java/SmbjTest.java b/src/it/java/SmbjTest.java index 6f4c1fca..59f59f50 100644 --- a/src/it/java/SmbjTest.java +++ b/src/it/java/SmbjTest.java @@ -12,14 +12,15 @@ import com.hierynomus.smbj.auth.AuthenticationContext; import com.hierynomus.smbj.common.SMBApiException; import com.hierynomus.smbj.connection.Connection; +import com.hierynomus.smbj.io.InputStreamByteChunkProvider; import com.hierynomus.smbj.session.Session; import com.hierynomus.smbj.share.Directory; import com.hierynomus.smbj.share.DiskShare; -import com.hierynomus.smbj.smb2.SMB2CompletionFilter; -import com.hierynomus.smbj.smb2.SMB2CreateDisposition; -import com.hierynomus.smbj.smb2.SMB2ShareAccess; -import com.hierynomus.smbj.smb2.messages.SMB2ChangeNotifyRequest; -import com.hierynomus.smbj.smb2.messages.SMB2ChangeNotifyResponse; +import com.hierynomus.mssmb2.SMB2CompletionFilter; +import com.hierynomus.mssmb2.SMB2CreateDisposition; +import com.hierynomus.mssmb2.SMB2ShareAccess; +import com.hierynomus.mssmb2.messages.SMB2ChangeNotifyRequest; +import com.hierynomus.mssmb2.messages.SMB2ChangeNotifyResponse; import com.hierynomus.smbj.transport.TransportException; import org.bouncycastle.jce.provider.BouncyCastleProvider; import org.junit.BeforeClass; @@ -27,12 +28,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.File; -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.io.UnsupportedEncodingException; +import java.io.*; import java.net.MalformedURLException; import java.net.URISyntaxException; import java.net.URL; @@ -40,20 +36,13 @@ import java.nio.file.Files; import java.nio.file.Paths; import java.security.Security; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.EnumSet; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.UUID; +import java.util.*; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; +import static org.junit.Assert.*; /** * Integration test Pre-Req @@ -180,7 +169,7 @@ public void testBasic() throws IOException, SMBApiException, URISyntaxException // Download and compare with originals File tmpFile1 = File.createTempFile("smbj", "junit"); try (OutputStream os = new FileOutputStream(tmpFile1)) { - share.read(fix(TEST_PATH + "/1/" + file1), os, null); + read(share, fix(TEST_PATH + "/1/" + file1), os); } assertFileContent("testfiles/medium.txt", tmpFile1.getAbsolutePath()); @@ -213,6 +202,89 @@ public void testBasic() throws IOException, SMBApiException, URISyntaxException } } + @Test + public void testWithSingleConnectionMultipleClients() throws IOException, SMBApiException, URISyntaxException { + logger.info("Connect {},{},{},{}", ci.host, ci.user, ci.domain, ci.sharePath); + SMBClient client = new SMBClient(); + final Connection connection = client.connect(ci.host); + AuthenticationContext ac = new AuthenticationContext( + ci.user, + ci.password == null ? new char[0] : ci.password.toCharArray(), + ci.domain); + Session session = connection.authenticate(ac); + + try (DiskShare share = (DiskShare)session.connectShare(ci.sharePath)) { + try { + share.rmdir(TEST_PATH, true); + } catch (SMBApiException sae) { + if (sae.getStatus() != NtStatus.STATUS_OBJECT_NAME_NOT_FOUND) { + throw sae; + } + } + share.mkdir(fix(TEST_PATH)); + assertTrue(share.folderExists(fix(TEST_PATH))); + ExecutorService executor = Executors.newFixedThreadPool(10); + final List exceptions = new ArrayList<>(); + for (int i = 0; i < 100; i++) { + executor.execute(new Runnable() { + @Override + public void run() { + String fl = UUID.randomUUID().toString() + ".txt"; + byte[] expectedBytes = generateRandomBytes(100000); + InputStream is = new ByteArrayInputStream(expectedBytes); + String path = fix(TEST_PATH + "\\" + fl); + try { + File localTmpFile = File.createTempFile("smbj", "junit"); + com.hierynomus.smbj.share.File file = share.openFile(path, + EnumSet.of(AccessMask.GENERIC_WRITE), SMB2CreateDisposition.FILE_OVERWRITE_IF); + + OutputStream out = file.getOutputStream(); + int numRead = -1; + byte[] buf = new byte[10000]; + while ((numRead = is.read(buf)) != -1) { + out.write(buf, 0, numRead); + } + out.close(); + file.close(); + + try (OutputStream os = new FileOutputStream(localTmpFile)) { + read(share, path, os); + } + byte[] actualBytes = Files.readAllBytes(localTmpFile.toPath()); + assertArrayEquals(expectedBytes, actualBytes); + localTmpFile.delete(); + share.rm(path); + } catch (Exception e) { + exceptions.add(e); + } + } + }); + } + executor.shutdown(); + while (!executor.isTerminated()) { + try { + executor.awaitTermination(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + // ignore + } + } + for (Throwable e : exceptions) { + e.printStackTrace(); + } + assertEquals("Errors encountered with multiple threads", 0, exceptions.size()); + share.rmdir(TEST_PATH, true); + assertFalse(share.folderExists(fix(TEST_PATH))); + } finally { + session.close(); + } + } + + private static byte[] generateRandomBytes(final int size) { + byte[] randomBytes = new byte[size]; + new Random().nextBytes(randomBytes); + return randomBytes; + } + @Test public void testRpc() throws IOException, SMBApiException, URISyntaxException { logger.info("Connect {},{},{},{}", ci.host, ci.user, ci.domain, ci.sharePath); @@ -299,7 +371,7 @@ List notify( Connection connection = session.getConnection(); SMB2ChangeNotifyRequest cnr = new SMB2ChangeNotifyRequest( - connection.getNegotiatedDialect(), + connection.getNegotiatedProtocol().getDialect(), session.getSessionId(), share.getTreeConnect().getTreeId(), directory.getFileId(), EnumSet.of( @@ -335,9 +407,28 @@ private void assertFileContent(String localResource, String downloadedFile) void write(DiskShare share, String remotePath, String localResource) throws IOException, SMBApiException { logger.debug("Writing {}, {} to {}", localResource, this.getClass().getResource(localResource), remotePath); + com.hierynomus.smbj.share.File file = share.openFile(remotePath, + EnumSet.of(AccessMask.GENERIC_WRITE), SMB2CreateDisposition.FILE_OVERWRITE_IF); try (InputStream is = this.getClass().getResourceAsStream(localResource)) { - share.write(remotePath, true, is, null); + file.write(new InputStreamByteChunkProvider(is)); } + file.close(); + } + + void write(DiskShare share, String remotePath, InputStream is) + throws IOException, SMBApiException { + com.hierynomus.smbj.share.File file = share.openFile(remotePath, + EnumSet.of(AccessMask.GENERIC_WRITE), SMB2CreateDisposition.FILE_OVERWRITE_IF); + file.write(new InputStreamByteChunkProvider(is)); + file.close(); + } + + void read(DiskShare share, String remotePath, OutputStream os) + throws IOException, SMBApiException { + com.hierynomus.smbj.share.File file = share.openFile(remotePath, + EnumSet.of(AccessMask.GENERIC_READ), SMB2CreateDisposition.FILE_OPEN); + file.read(os); + file.close(); } private void assertFilesInPathEquals(DiskShare share, String[] expected, String path) diff --git a/src/main/java/com/hierynomus/smbj/connection/SequenceWindow.java b/src/main/java/com/hierynomus/smbj/connection/SequenceWindow.java index 2a417898..4d706b3b 100644 --- a/src/main/java/com/hierynomus/smbj/connection/SequenceWindow.java +++ b/src/main/java/com/hierynomus/smbj/connection/SequenceWindow.java @@ -18,6 +18,7 @@ import com.hierynomus.smbj.common.SMBRuntimeException; import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; /** @@ -39,18 +40,27 @@ class SequenceWindow { static final int PREFERRED_MINIMUM_CREDITS = 512; private AtomicLong lowestAvailable = new AtomicLong(0); private Semaphore available = new Semaphore(1); - - long get() { - if (available.tryAcquire()) { - return lowestAvailable.getAndIncrement(); + private static final long MAX_WAIT = 5000; + + public long get() { + try { + if (available.tryAcquire(MAX_WAIT, TimeUnit.MILLISECONDS)) { + return lowestAvailable.getAndIncrement(); + } + } catch (InterruptedException e) { + //ignore } throw new SMBRuntimeException("No more credits available to hand out sequence number"); } - long[] get(int credits) { - if (available.tryAcquire(credits)) { - long lowest = lowestAvailable.getAndAdd(credits); - return range(lowest, lowest + credits); + public long[] get(int credits) { + try { + if (available.tryAcquire(credits, MAX_WAIT, TimeUnit.MILLISECONDS)) { + long lowest = lowestAvailable.getAndAdd(credits); + return range(lowest, lowest + credits); + } + } catch (InterruptedException e) { + //ignore } throw new SMBRuntimeException("Not enough credits (" + available.availablePermits() + " available) to hand out " + credits + " sequence numbers"); } @@ -87,11 +97,21 @@ public boolean tryAcquire() { return true; } + @Override + public boolean tryAcquire(long timeout, TimeUnit unit) { + return true; + } + @Override public boolean tryAcquire(int permits) { return true; } + @Override + public boolean tryAcquire(int permits, long timeout, TimeUnit unit) { + return true; + } + @Override public void release(int permits) { // no-op diff --git a/src/main/java/com/hierynomus/smbj/io/InputStreamByteChunkProvider.java b/src/main/java/com/hierynomus/smbj/io/InputStreamByteChunkProvider.java new file mode 100644 index 00000000..cc143276 --- /dev/null +++ b/src/main/java/com/hierynomus/smbj/io/InputStreamByteChunkProvider.java @@ -0,0 +1,56 @@ +/* + * Copyright (C)2016 - SMBJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.hierynomus.smbj.io; + +import com.hierynomus.smbj.common.SMBRuntimeException; + +import java.io.*; + +public class InputStreamByteChunkProvider extends ByteChunkProvider { + + private BufferedInputStream is; + + public InputStreamByteChunkProvider(InputStream is) { + if (is instanceof BufferedInputStream) + this.is = (BufferedInputStream) is; + else + this.is = new BufferedInputStream(is); + } + + @Override + protected int getChunk(byte[] chunk) throws IOException { + int count = 0; + int read = 0; + while (count < CHUNK_SIZE && ((read = is.read(chunk, count, CHUNK_SIZE - count)) != -1)) { + count += read; + } + return count; + } + + @Override + public int bytesLeft() { + try { + return is.available(); + } catch (IOException e) { + throw new SMBRuntimeException(e); + } + } + + @Override + public boolean isAvailable() { + return bytesLeft() > 0; + } +} diff --git a/src/main/java/com/hierynomus/smbj/share/File.java b/src/main/java/com/hierynomus/smbj/share/File.java index cf2eb4a5..94943496 100644 --- a/src/main/java/com/hierynomus/smbj/share/File.java +++ b/src/main/java/com/hierynomus/smbj/share/File.java @@ -17,8 +17,6 @@ import com.hierynomus.mserref.NtStatus; import com.hierynomus.mssmb2.SMB2FileId; -import com.hierynomus.mssmb2.messages.SMB2ReadRequest; -import com.hierynomus.mssmb2.messages.SMB2ReadResponse; import com.hierynomus.mssmb2.messages.SMB2WriteRequest; import com.hierynomus.mssmb2.messages.SMB2WriteResponse; import com.hierynomus.protocol.commons.concurrent.Futures; @@ -32,6 +30,7 @@ import org.slf4j.LoggerFactory; import java.io.IOException; +import java.io.InputStream; import java.io.OutputStream; import java.util.concurrent.Future; @@ -57,61 +56,43 @@ public void write(ByteChunkProvider provider, ProgressListener progressListener) throw new SMBApiException(wresp.getHeader().getStatus(), "Write failed for " + this); } if (progressListener != null) progressListener.onProgressChanged(wresp.getBytesWritten(), provider.getOffset()); - } } -// public void write(InputStream srcStream, ProgressListener progressListener) throws IOException, SMBApiException { -// -// Session session = treeConnect.getSession(); -// Connection connection = session.getConnection(); -// -// byte[] buf = new byte[connection.getNegotiatedProtocol().getMaxWriteSize()]; -// int numRead = -1; -// int offset = 0; -// -// while ((numRead = srcStream.read(buf)) != -1) { -// //logger.debug("Writing {} bytes", numRead); -// SMB2WriteRequest wreq = new SMB2WriteRequest(connection.getNegotiatedProtocol().getDialect(), getFileId(), -// session.getSessionId(), treeConnect.getTreeId(), -// buf, numRead, offset, 0); -// Future writeFuture = connection.send(wreq); -// SMB2WriteResponse wresp = Futures.get(writeFuture, TransportException.Wrapper); -// -// if (wresp.getHeader().getStatus() != NtStatus.STATUS_SUCCESS) { -// throw new SMBApiException(wresp.getHeader().getStatus(), "Write failed for " + this); -// } -// offset += numRead; -// if (progressListener != null) progressListener.onProgressChanged(offset, -1); -// } -// } + public void write(ByteChunkProvider provider) throws IOException { + write(provider, null); + } + + public void read(OutputStream destStream) throws IOException { + read(destStream, null); + } - public void read(OutputStream destStream, ProgressListener progressListener) throws IOException, - SMBApiException { + public void read(OutputStream destStream, ProgressListener progressListener) throws IOException { Session session = treeConnect.getSession(); Connection connection = session.getConnection(); + InputStream is = getInputStream(progressListener); + int numRead = -1; + byte[] buf = new byte[connection.getNegotiatedProtocol().getMaxWriteSize()]; + while ((numRead = is.read(buf)) != -1) { + destStream.write(buf, 0, numRead); + } + is.close(); + } - long offset = 0; - SMB2ReadRequest rreq = new SMB2ReadRequest(connection.getNegotiatedProtocol(), getFileId(), - session.getSessionId(), treeConnect.getTreeId(), offset); + public InputStream getInputStream() { + return getInputStream(null); + } - Future readResponseFuture = connection.send(rreq); - SMB2ReadResponse rresp = Futures.get(readResponseFuture, TransportException.Wrapper); + private InputStream getInputStream(final ProgressListener listener) { + return new FileInputStream(fileId, treeConnect, listener); + } - while (rresp.getHeader().getStatus() == NtStatus.STATUS_SUCCESS && - rresp.getHeader().getStatus() != NtStatus.STATUS_END_OF_FILE) { - destStream.write(rresp.getData()); - offset += rresp.getDataLength(); - rreq = new SMB2ReadRequest(connection.getNegotiatedProtocol(), getFileId(), - session.getSessionId(), treeConnect.getTreeId(), offset); - readResponseFuture = connection.send(rreq); - rresp = Futures.get(readResponseFuture, TransportException.Wrapper); - if (progressListener != null) progressListener.onProgressChanged(offset, -1); - } + public OutputStream getOutputStream() { + return getOutputStream(null); + } - if (rresp.getHeader().getStatus() != NtStatus.STATUS_END_OF_FILE) { - throw new SMBApiException(rresp.getHeader().getStatus(), "Read failed for " + this); - } + private OutputStream getOutputStream(final ProgressListener listener) { + return new FileOutputStream(fileId, treeConnect, listener); } @Override @@ -121,5 +102,4 @@ public String toString() { ", fileName='" + fileName + '\'' + '}'; } - } diff --git a/src/main/java/com/hierynomus/smbj/share/FileInputStream.java b/src/main/java/com/hierynomus/smbj/share/FileInputStream.java new file mode 100644 index 00000000..651eb806 --- /dev/null +++ b/src/main/java/com/hierynomus/smbj/share/FileInputStream.java @@ -0,0 +1,126 @@ +/* + * Copyright (C)2016 - SMBJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.hierynomus.smbj.share; + +import com.hierynomus.mserref.NtStatus; +import com.hierynomus.mssmb2.SMB2FileId; +import com.hierynomus.mssmb2.messages.SMB2ReadRequest; +import com.hierynomus.mssmb2.messages.SMB2ReadResponse; +import com.hierynomus.protocol.commons.concurrent.Futures; +import com.hierynomus.smbj.ProgressListener; +import com.hierynomus.smbj.common.SMBApiException; +import com.hierynomus.smbj.connection.Connection; +import com.hierynomus.smbj.session.Session; +import com.hierynomus.smbj.transport.TransportException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.InputStream; +import java.util.concurrent.Future; + +public class FileInputStream extends InputStream { + + protected TreeConnect treeConnect; + private Session session; + private Connection connection; + private SMB2FileId fileId; + private long offset = 0; + private int curr = 0; + private byte[] buf; + private ProgressListener progressListener; + private boolean isClosed; + private Future nextResponse; + + private static final Logger logger = LoggerFactory.getLogger(FileInputStream.class); + + public FileInputStream(SMB2FileId fileId, TreeConnect treeConnect, ProgressListener progressListener) { + this.treeConnect = treeConnect; + this.fileId = fileId; + this.session = treeConnect.getSession(); + this.connection = session.getConnection(); + this.progressListener = progressListener; + } + + @Override + public int read() throws IOException { + if (buf == null || curr >= buf.length) { + loadBuffer(); + } + if (isClosed) return -1; + ++curr; + return buf[curr - 1] & 0xFF; + } + + @Override + public int read(byte b[]) throws IOException { + return read(b, 0, b.length); + } + + @Override + public int read(byte b[], int off, int len) throws IOException { + if (buf == null || curr >= buf.length) { + loadBuffer(); + } + if (isClosed) return -1; + int l = buf.length - curr > len ? len : buf.length - curr; + System.arraycopy(buf, curr, b, off, l); + curr = curr + l; + return l; + } + + @Override + public void close() throws IOException { + isClosed = true; + session = null; + connection = null; + buf = null; + } + + @Override + public int available() throws IOException { + throw new IOException("Available not supported"); + } + + private void loadBuffer() throws IOException { + + if (nextResponse == null) + nextResponse = sendRequest(); + + SMB2ReadResponse res = Futures.get(nextResponse, TransportException.Wrapper); + if (res.getHeader().getStatus() == NtStatus.STATUS_SUCCESS) { + buf = res.getData(); + curr = 0; + offset += res.getDataLength(); + if (progressListener != null) progressListener.onProgressChanged(offset, -1); + } + if (res.getHeader().getStatus() == NtStatus.STATUS_END_OF_FILE) { + logger.debug("EOF, {} bytes read", offset); + isClosed = true; + return; + } + if (res.getHeader().getStatus() != NtStatus.STATUS_SUCCESS) { + throw new SMBApiException(res.getHeader().getStatus(), "Read failed for " + this); + } + nextResponse = sendRequest(); + } + + private Future sendRequest() throws IOException { + SMB2ReadRequest rreq = new SMB2ReadRequest(connection.getNegotiatedProtocol(), fileId, + session.getSessionId(), treeConnect.getTreeId(), offset); + return connection.send(rreq); + } +} diff --git a/src/main/java/com/hierynomus/smbj/share/FileOutputStream.java b/src/main/java/com/hierynomus/smbj/share/FileOutputStream.java new file mode 100644 index 00000000..0ca2d2b3 --- /dev/null +++ b/src/main/java/com/hierynomus/smbj/share/FileOutputStream.java @@ -0,0 +1,178 @@ +/* + * Copyright (C)2016 - SMBJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.hierynomus.smbj.share; + +import com.hierynomus.mserref.NtStatus; +import com.hierynomus.mssmb2.SMB2FileId; +import com.hierynomus.mssmb2.messages.SMB2WriteRequest; +import com.hierynomus.mssmb2.messages.SMB2WriteResponse; +import com.hierynomus.protocol.commons.concurrent.Futures; +import com.hierynomus.smbj.ProgressListener; +import com.hierynomus.smbj.common.SMBApiException; +import com.hierynomus.smbj.connection.Connection; +import com.hierynomus.smbj.io.ByteChunkProvider; +import com.hierynomus.smbj.session.Session; +import com.hierynomus.smbj.transport.TransportException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.concurrent.Future; + +public class FileOutputStream extends OutputStream { + + private TreeConnect treeConnect; + private SMB2FileId fileId; + private Session session; + private Connection connection; + private int maxWriteSize; + private ProgressListener progressListener; + private boolean isClosed = false; + private ByteArrayProvider provider; + + private static final Logger logger = LoggerFactory.getLogger(FileOutputStream.class); + + public FileOutputStream(SMB2FileId fileId, TreeConnect treeConnect, ProgressListener progressListener) { + this.treeConnect = treeConnect; + this.fileId = fileId; + this.session = treeConnect.getSession(); + this.connection = session.getConnection(); + this.progressListener = progressListener; + this.maxWriteSize = connection.getNegotiatedProtocol().getMaxWriteSize(); + this.provider = new ByteArrayProvider(this.maxWriteSize); + } + + @Override + public void write(int b) throws IOException { + verifyConnectionNotClosed(); + + if (provider.getCurrentSize() < maxWriteSize) { + provider.getBuf()[provider.getCurrentSize()] = (byte) b; + provider.incCurrentSize(); + } + if (provider.getCurrentSize() == maxWriteSize) flush(); + } + + @Override + public void write(byte b[]) throws IOException { + write(b, 0, b.length); + } + + @Override + public void write(byte b[], int off, int len) throws IOException { + verifyConnectionNotClosed(); + if (provider.getCurrentSize() < maxWriteSize) { + System.arraycopy(b, off, provider.getBuf(), provider.getCurrentSize(), len); + provider.incCurrentSize(len); + } + if (provider.getCurrentSize() == maxWriteSize) flush(); + } + + @Override + public void flush() throws IOException { + verifyConnectionNotClosed(); + + while (provider.isAvailable()) { + SMB2WriteRequest wreq = new SMB2WriteRequest(connection.getNegotiatedProtocol().getDialect(), fileId, + session.getSessionId(), treeConnect.getTreeId(), provider, connection.getNegotiatedProtocol().getMaxWriteSize()); + Future writeFuture = connection.send(wreq); + SMB2WriteResponse wresp = Futures.get(writeFuture, TransportException.Wrapper); + if (wresp.getHeader().getStatus() != NtStatus.STATUS_SUCCESS) { + throw new SMBApiException(wresp.getHeader().getStatus(), "Write failed for " + this); + } + provider.resetCurrentSize(); + provider.resetReadPosition(); + if (progressListener != null) + progressListener.onProgressChanged(wresp.getBytesWritten(), provider.getOffset()); + } + } + + @Override + public void close() throws IOException { + flush(); + isClosed = true; + provider.clean(); + treeConnect = null; + session = null; + connection = null; + logger.debug("EOF, {} bytes written", provider.getOffset()); + } + + private void verifyConnectionNotClosed() throws IOException { + if (isClosed) throw new IOException("Stream is closed"); + } + + private static class ByteArrayProvider extends ByteChunkProvider { + + private byte[] buf; + private int maxWriteSize; + private int currentSize; + private int readPosition; + + private ByteArrayProvider(int maxWriteSize) { + this.maxWriteSize = maxWriteSize; + } + + @Override + public boolean isAvailable() { + return currentSize - readPosition > 0; + } + + @Override + protected int getChunk(byte[] chunk) throws IOException { + int len = currentSize - readPosition < chunk.length ? currentSize - readPosition : chunk.length; + System.arraycopy(buf, readPosition, chunk, 0, len); + readPosition = readPosition + len; + return len; + } + + @Override + public int bytesLeft() { + return currentSize - readPosition; + } + + private byte[] getBuf() { + if (buf == null) + buf = new byte[maxWriteSize]; + return buf; + } + + private void clean() { + buf = null; + } + + private int getCurrentSize() { + return currentSize; + } + + private void incCurrentSize() { + incCurrentSize(1); + } + + private void incCurrentSize(int i) { + currentSize += i; + } + + private void resetCurrentSize() { + currentSize = 0; + } + + private void resetReadPosition() { + readPosition = 0; + } + } +}