diff --git a/src/main/java/com/hierynomus/sshj/socket/Sockets.java b/src/main/java/com/hierynomus/sshj/socket/Sockets.java new file mode 100644 index 000000000..945869f95 --- /dev/null +++ b/src/main/java/com/hierynomus/sshj/socket/Sockets.java @@ -0,0 +1,26 @@ +package com.hierynomus.sshj.socket; + +import java.io.Closeable; +import java.io.IOException; +import java.net.Socket; + +public class Sockets { + + /** + * Java 7 and up have Socket implemented as Closeable, whereas Java6 did not have this inheritance. + * @param socket The socket to wrap as Closeable + * @return + */ + public static Closeable asCloseable(final Socket socket) { + if (Closeable.class.isAssignableFrom(socket.getClass())) { + return Closeable.class.cast(socket); + } else { + return new Closeable() { + @Override + public void close() throws IOException { + socket.close(); + } + }; + } + } +} diff --git a/src/main/java/net/schmizz/sshj/connection/channel/SocketStreamCopyMonitor.java b/src/main/java/net/schmizz/sshj/connection/channel/SocketStreamCopyMonitor.java index a6c524df5..be9fb7c14 100644 --- a/src/main/java/net/schmizz/sshj/connection/channel/SocketStreamCopyMonitor.java +++ b/src/main/java/net/schmizz/sshj/connection/channel/SocketStreamCopyMonitor.java @@ -15,6 +15,7 @@ */ package net.schmizz.sshj.connection.channel; +import com.hierynomus.sshj.socket.Sockets; import net.schmizz.concurrent.Event; import net.schmizz.sshj.common.IOUtils; @@ -23,6 +24,8 @@ import java.net.Socket; import java.util.concurrent.TimeUnit; +import static com.hierynomus.sshj.socket.Sockets.asCloseable; + public class SocketStreamCopyMonitor extends Thread { @@ -32,16 +35,6 @@ private SocketStreamCopyMonitor(Runnable r) { setDaemon(true); } - private static Closeable wrapSocket(final Socket socket) { - return new Closeable() { - @Override - public void close() - throws IOException { - socket.close(); - } - }; - } - public static void monitor(final int frequency, final TimeUnit unit, final Event x, final Event y, final Channel channel, final Socket socket) { @@ -54,7 +47,7 @@ public void run() { } } catch (IOException ignored) { } finally { - IOUtils.closeQuietly(channel, wrapSocket(socket)); + IOUtils.closeQuietly(channel, asCloseable(socket)); } } }).start(); diff --git a/src/main/java/net/schmizz/sshj/connection/channel/direct/LocalPortForwarder.java b/src/main/java/net/schmizz/sshj/connection/channel/direct/LocalPortForwarder.java index 9ea10bf51..0c0dfb824 100644 --- a/src/main/java/net/schmizz/sshj/connection/channel/direct/LocalPortForwarder.java +++ b/src/main/java/net/schmizz/sshj/connection/channel/direct/LocalPortForwarder.java @@ -16,12 +16,12 @@ package net.schmizz.sshj.connection.channel.direct; import net.schmizz.concurrent.Event; +import net.schmizz.sshj.common.IOUtils; import net.schmizz.sshj.common.SSHPacket; import net.schmizz.sshj.common.StreamCopier; import net.schmizz.sshj.connection.Connection; -import net.schmizz.sshj.connection.ConnectionException; import net.schmizz.sshj.connection.channel.SocketStreamCopyMonitor; -import net.schmizz.sshj.transport.TransportException; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -30,6 +30,8 @@ import java.net.Socket; import java.util.concurrent.TimeUnit; +import static com.hierynomus.sshj.socket.Sockets.asCloseable; + public class LocalPortForwarder { public static class Parameters { @@ -112,11 +114,15 @@ public LocalPortForwarder(Connection conn, Parameters parameters, ServerSocket s this.serverSocket = serverSocket; } - protected DirectTCPIPChannel openChannel(Socket socket) - throws TransportException, ConnectionException { - final DirectTCPIPChannel chan = new DirectTCPIPChannel(conn, socket, parameters); - chan.open(); - return chan; + private void startChannel(Socket socket) throws IOException { + DirectTCPIPChannel chan = new DirectTCPIPChannel(conn, socket, parameters); + try { + chan.open(); + chan.start(); + } catch (IOException e) { + IOUtils.closeQuietly(chan, asCloseable(socket)); + throw e; + } } /** @@ -130,7 +136,7 @@ public void listen() while (!Thread.currentThread().isInterrupted()) { final Socket socket = serverSocket.accept(); log.debug("Got connection from {}", socket.getRemoteSocketAddress()); - openChannel(socket).start(); + startChannel(socket); } log.debug("Interrupted!"); }