diff --git a/src/main/java/com/jcraft/jsch/Session.java b/src/main/java/com/jcraft/jsch/Session.java index 97b3d708..433554b1 100644 --- a/src/main/java/com/jcraft/jsch/Session.java +++ b/src/main/java/com/jcraft/jsch/Session.java @@ -93,6 +93,8 @@ public class Session { private byte[] MACc2s; private byte[] MACs2c; + // RFC 4253 6.4. each direction must run independently hence we have an incoming and outgoing + // sequence number private int seqi = 0; private int seqo = 0; @@ -1998,6 +2000,13 @@ public void disconnect() { } io = null; socket = null; + + // RFC 4253 6.4. the 'sequence_number' is never reset, even if keys/algorithms are renegotiated + // later. Hence we only reset these on session disconnect as the sequence has to start at zero + // for the first packet during (re)connect. + seqi = 0; + seqo = 0; + // synchronized(jsch.pool){ // jsch.pool.removeElement(this); // } diff --git a/src/test/java/com/jcraft/jsch/SessionReconnectIT.java b/src/test/java/com/jcraft/jsch/SessionReconnectIT.java new file mode 100644 index 00000000..f2b603ef --- /dev/null +++ b/src/test/java/com/jcraft/jsch/SessionReconnectIT.java @@ -0,0 +1,160 @@ +package com.jcraft.jsch; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + +import com.github.valfirst.slf4jtest.LoggingEvent; +import com.github.valfirst.slf4jtest.TestLogger; +import com.github.valfirst.slf4jtest.TestLoggerFactory; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Base64; +import java.util.List; +import java.util.Random; +import org.apache.commons.codec.digest.DigestUtils; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.output.Slf4jLogConsumer; +import org.testcontainers.images.builder.ImageFromDockerfile; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +@Testcontainers +public class SessionReconnectIT { + + private static final int timeout = 2000; + private static final DigestUtils sha256sum = new DigestUtils(DigestUtils.getSha256Digest()); + private static final TestLogger jschLogger = TestLoggerFactory.getTestLogger(JSch.class); + private static final TestLogger sshdLogger = + TestLoggerFactory.getTestLogger(SessionReconnectIT.class); + + @TempDir + public Path tmpDir; + private Path in; + private Path out; + private String hash; + private Slf4jLogConsumer sshdLogConsumer; + + @Container + public GenericContainer sshd = new GenericContainer<>(new ImageFromDockerfile() + .withFileFromClasspath("dropbear_rsa_host_key", "docker/dropbear_rsa_host_key") + .withFileFromClasspath("authorized_keys", "docker/authorized_keys") + .withFileFromClasspath("Dockerfile", "docker/Dockerfile.dropbear")).withExposedPorts(22); + + @BeforeAll + public static void beforeAll() { + JSch.setLogger(new Slf4jLogger()); + } + + @BeforeEach + public void beforeEach() throws IOException { + if (sshdLogConsumer == null) { + sshdLogConsumer = new Slf4jLogConsumer(sshdLogger); + sshd.followOutput(sshdLogConsumer); + } + + in = tmpDir.resolve("in"); + out = tmpDir.resolve("out"); + Files.createFile(in); + try (OutputStream os = Files.newOutputStream(in)) { + byte[] data = new byte[1024]; + for (int i = 0; i < 1024 * 100; i += 1024) { + new Random().nextBytes(data); + os.write(data); + } + } + hash = sha256sum.digestAsHex(in); + + jschLogger.clearAll(); + sshdLogger.clearAll(); + } + + @AfterAll + public static void afterAll() { + JSch.setLogger(null); + jschLogger.clearAll(); + sshdLogger.clearAll(); + } + + @Test + public void testReconnectWithExtraAlgorithms() throws Exception { + JSch ssh = createRSAIdentity(); + Session session = createSession(ssh); + try { + doSftp(session, false); + fail("exception expected"); + } catch (JSchAlgoNegoFailException e) { + // Dropbear does not support rsa-sha2-512/rsa-sha2-256, so add ssh-rsa + String serverHostKey = session.getConfig("server_host_key") + ",ssh-rsa"; + String pubkeyAcceptedAlgorithms = session.getConfig("PubkeyAcceptedAlgorithms") + ",ssh-rsa"; + session.setConfig("server_host_key", serverHostKey); + session.setConfig("PubkeyAcceptedAlgorithms", pubkeyAcceptedAlgorithms); + doSftp(session, true); + } + } + + private JSch createRSAIdentity() throws Exception { + HostKey hostKey = readHostKey(getResourceFile("docker/ssh_host_rsa_key.pub")); + JSch ssh = new JSch(); + ssh.addIdentity(getResourceFile("docker/id_rsa"), getResourceFile("docker/id_rsa.pub"), null); + ssh.getHostKeyRepository().add(hostKey, null); + return ssh; + } + + private HostKey readHostKey(String fileName) throws Exception { + List lines = Files.readAllLines(Paths.get(fileName), UTF_8); + String[] split = lines.get(0).split("\\s+"); + String hostname = String.format("[%s]:%d", sshd.getHost(), sshd.getFirstMappedPort()); + return new HostKey(hostname, Base64.getDecoder().decode(split[1])); + } + + private Session createSession(JSch ssh) throws Exception { + Session session = ssh.getSession("root", sshd.getHost(), sshd.getFirstMappedPort()); + session.setConfig("StrictHostKeyChecking", "yes"); + session.setConfig("PreferredAuthentications", "publickey"); + return session; + } + + private void doSftp(Session session, boolean debugException) throws Exception { + try { + session.setTimeout(timeout); + session.connect(); + ChannelSftp sftp = (ChannelSftp) session.openChannel("sftp"); + sftp.connect(timeout); + sftp.put(in.toString(), "/root/test"); + sftp.get("/root/test", out.toString()); + sftp.disconnect(); + session.disconnect(); + } catch (Exception e) { + if (debugException) { + printInfo(); + } + throw e; + } + + assertEquals(1024L * 100L, Files.size(out)); + assertEquals(hash, sha256sum.digestAsHex(out)); + } + + private void printInfo() { + jschLogger.getAllLoggingEvents().stream().map(LoggingEvent::getFormattedMessage) + .forEach(System.out::println); + sshdLogger.getAllLoggingEvents().stream().map(LoggingEvent::getFormattedMessage) + .forEach(System.out::println); + System.out.println(""); + System.out.println(""); + System.out.println(""); + } + + private String getResourceFile(String fileName) { + return ResourceUtil.getResourceFile(getClass(), fileName); + } +}