Skip to content

Commit

Permalink
Slightly reduce the locking in Connection.send and DirectTcpTransport…
Browse files Browse the repository at this point in the history
… (fixes \#732)
  • Loading branch information
hierynomus committed May 8, 2023
1 parent e516cad commit 537d692
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 34 deletions.
35 changes: 18 additions & 17 deletions src/main/java/com/hierynomus/smbj/connection/Connection.java
Original file line number Diff line number Diff line change
Expand Up @@ -210,32 +210,33 @@ public Session createSession(AuthenticationContext context) {
* @throws TransportException When a transport level error occurred
*/
public <T extends SMB2Packet> Future<T> send(SMB2Packet packet) throws TransportException {
lock.lock();
Future<T> f = null;
try {
if (!(packet.getPacket() instanceof SMB2Cancel)) {
int availableCredits = sequenceWindow.available();
int grantCredits = calculateGrantedCredits(packet, availableCredits);
if (availableCredits == 0) {
logger.warn(
"There are no credits left to send {}, will block until there are more credits available.",
packet.getHeader().getMessage());
// Need to lock around the sequence window calls to ensure no credits get stolen by another thread
lock.lock();
try {
int availableCredits = sequenceWindow.available();
int grantCredits = calculateGrantedCredits(packet, availableCredits);
if (availableCredits == 0) {
logger.warn(
"There are no credits left to send {}, will block until there are more credits available.",
packet.getHeader().getMessage());
}
long[] messageIds = sequenceWindow.get(grantCredits);
packet.getHeader().setMessageId(messageIds[0]);
packet.getHeader().setCreditRequest(Math.max(SequenceWindow.PREFERRED_MINIMUM_CREDITS - availableCredits - grantCredits,
grantCredits));
logger.debug("Granted {} (out of {}) credits to {}", grantCredits, availableCredits, packet);
} finally {
lock.unlock();
}
long[] messageIds = sequenceWindow.get(grantCredits);
packet.getHeader().setMessageId(messageIds[0]);
logger.debug("Granted {} (out of {}) credits to {}", grantCredits, availableCredits, packet);
packet.getHeader().setCreditRequest(Math
.max(SequenceWindow.PREFERRED_MINIMUM_CREDITS - availableCredits - grantCredits, grantCredits));

Request request = new Request(packet.getPacket(), messageIds[0], UUID.randomUUID());
Request request = new Request(packet.getPacket(), packet.getHeader().getMessageId(), UUID.randomUUID());
outstandingRequests.registerOutstanding(request);
f = request.getFuture(new CancelRequest(request, packet.getHeader().getSessionId()));
}
transport.write(packet);
return f;
} finally {
lock.unlock();
}
}

<T extends SMB2Packet> T sendAndReceive(SMB2Packet packet) throws TransportException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReentrantLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;

import static java.lang.String.format;

Expand All @@ -42,7 +44,7 @@ public class DirectTcpTransport<D extends PacketData<?>, P extends Packet<?>> im
private final Logger logger = LoggerFactory.getLogger(this.getClass());
private final PacketHandlers<D, P> handlers;

private final ReentrantLock writeLock = new ReentrantLock();
private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();

private SocketFactory socketFactory = new ProxySocketFactory();
private int soTimeout;
Expand All @@ -62,23 +64,28 @@ public DirectTcpTransport(SocketFactory socketFactory, int soTimeout, PacketHand
@Override
public void write(P packet) throws TransportException {
logger.trace("Acquiring write lock to send packet << {} >>", packet);
writeLock.lock();
// isConnected only locks readlock, so check first and check once write lock
// is acquired to prevent race
if (!isConnected()) {
throw new TransportException(format("Cannot write %s as transport is disconnected", packet));
}

lock.writeLock().lock();
try {
if (!isConnected()) {
throw new TransportException(format("Cannot write %s as transport is disconnected", packet));
}
try {
logger.debug("Writing packet {}", packet);
Buffer<?> packetData = handlers.getSerializer().write(packet);
writeDirectTcpPacketHeader(packetData.available());
writePacketData(packetData);
output.flush();
logger.trace("Packet {} sent, lock released.", packet);
} catch (IOException ioe) {
throw new TransportException(ioe);
throw new TransportException(format("Cannot write %s as transport got disconnected", packet));
}

logger.debug("Writing packet {}", packet);
Buffer<?> packetData = handlers.getSerializer().write(packet);
writeDirectTcpPacketHeader(packetData.available());
writePacketData(packetData);
output.flush();
logger.trace("Packet {} sent, lock released.", packet);
} catch (IOException ioe) {
throw new TransportException(ioe);
} finally {
writeLock.unlock();
lock.writeLock().unlock();
}
}

Expand All @@ -99,32 +106,45 @@ private void initWithSocket(String remoteHostname) throws IOException {

@Override
public void disconnect() throws IOException {
writeLock.lock();
if (!isConnected()) {
return;
}

lock.writeLock().lock();
try {
// check again to prevent race
if (!isConnected()) {
return;
}

packetReaderThread.stop();

if (socket.getInputStream() != null) {
socket.getInputStream().close();
}

if (output != null) {
output.close();
output = null;
}

if (socket != null) {
socket.close();
socket = null;
}
} finally {
writeLock.unlock();
lock.writeLock().unlock();
}
}

@Override
public boolean isConnected() {
return (socket != null) && socket.isConnected() && !socket.isClosed();
lock.readLock().lock();
try {
return (socket != null) && socket.isConnected() && !socket.isClosed();
} finally {
lock.readLock().unlock();
}
}

public void setSocketFactory(SocketFactory socketFactory) {
Expand Down

0 comments on commit 537d692

Please sign in to comment.