diff --git a/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java b/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java index bc53a732..f8ad845a 100644 --- a/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java +++ b/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java @@ -170,11 +170,22 @@ private static void ensureReceivedMatchesExpected(Message got, Message expected) private void sendKexInit() throws TransportException { log.debug("Sending SSH_MSG_KEXINIT"); - clientProposal = new Proposal(transport.getConfig()); + List knownHostAlgs = findKnownHostAlgs(transport.getRemoteHost(), transport.getRemotePort()); + clientProposal = new Proposal(transport.getConfig(), knownHostAlgs); transport.write(clientProposal.getPacket()); kexInitSent.set(); } + private List findKnownHostAlgs(String hostname, int port) { + for (HostKeyVerifier hkv : hostVerifiers) { + List keyTypes = hkv.findExistingAlgorithms(hostname, port); + if (keyTypes != null && !keyTypes.isEmpty()) { + return keyTypes; + } + } + return Collections.emptyList(); + } + private void sendNewKeys() throws TransportException { log.debug("Sending SSH_MSG_NEWKEYS"); diff --git a/src/main/java/net/schmizz/sshj/transport/Proposal.java b/src/main/java/net/schmizz/sshj/transport/Proposal.java index ccc5a528..299cd57d 100644 --- a/src/main/java/net/schmizz/sshj/transport/Proposal.java +++ b/src/main/java/net/schmizz/sshj/transport/Proposal.java @@ -38,9 +38,9 @@ class Proposal { private final List s2cComp; private final SSHPacket packet; - public Proposal(Config config) { + public Proposal(Config config, List knownHostAlgs) { kex = Factory.Named.Util.getNames(config.getKeyExchangeFactories()); - sig = Factory.Named.Util.getNames(config.getKeyAlgorithms()); + sig = filterKnownHostKeyAlgorithms(Factory.Named.Util.getNames(config.getKeyAlgorithms()), knownHostAlgs); c2sCipher = s2cCipher = Factory.Named.Util.getNames(config.getCipherFactories()); c2sMAC = s2cMAC = Factory.Named.Util.getNames(config.getMACFactories()); c2sComp = s2cComp = Factory.Named.Util.getNames(config.getCompressionFactories()); @@ -127,32 +127,43 @@ public SSHPacket getPacket() { public NegotiatedAlgorithms negotiate(Proposal other) throws TransportException { return new NegotiatedAlgorithms( - firstMatch("KeyExchangeAlgorithms", - this.getKeyExchangeAlgorithms(), - other.getKeyExchangeAlgorithms()), - firstMatch("HostKeyAlgorithms", - this.getHostKeyAlgorithms(), - other.getHostKeyAlgorithms()), - firstMatch("Client2ServerCipherAlgorithms", - this.getClient2ServerCipherAlgorithms(), - other.getClient2ServerCipherAlgorithms()), - firstMatch("Server2ClientCipherAlgorithms", - this.getServer2ClientCipherAlgorithms(), - other.getServer2ClientCipherAlgorithms()), - firstMatch("Client2ServerMACAlgorithms", - this.getClient2ServerMACAlgorithms(), - other.getClient2ServerMACAlgorithms()), - firstMatch("Server2ClientMACAlgorithms", - this.getServer2ClientMACAlgorithms(), - other.getServer2ClientMACAlgorithms()), - firstMatch("Client2ServerCompressionAlgorithms", - this.getClient2ServerCompressionAlgorithms(), - other.getClient2ServerCompressionAlgorithms()), - firstMatch("Server2ClientCompressionAlgorithms", - this.getServer2ClientCompressionAlgorithms(), - other.getServer2ClientCompressionAlgorithms()), - other.getHostKeyAlgorithms().containsAll(KeyAlgorithms.SSH_RSA_SHA2_ALGORITHMS) - ); + firstMatch("KeyExchangeAlgorithms", this.getKeyExchangeAlgorithms(), other.getKeyExchangeAlgorithms()), + firstMatch("HostKeyAlgorithms", this.getHostKeyAlgorithms(), other.getHostKeyAlgorithms()), + firstMatch("Client2ServerCipherAlgorithms", this.getClient2ServerCipherAlgorithms(), + other.getClient2ServerCipherAlgorithms()), + firstMatch("Server2ClientCipherAlgorithms", this.getServer2ClientCipherAlgorithms(), + other.getServer2ClientCipherAlgorithms()), + firstMatch("Client2ServerMACAlgorithms", this.getClient2ServerMACAlgorithms(), + other.getClient2ServerMACAlgorithms()), + firstMatch("Server2ClientMACAlgorithms", this.getServer2ClientMACAlgorithms(), + other.getServer2ClientMACAlgorithms()), + firstMatch("Client2ServerCompressionAlgorithms", this.getClient2ServerCompressionAlgorithms(), + other.getClient2ServerCompressionAlgorithms()), + firstMatch("Server2ClientCompressionAlgorithms", this.getServer2ClientCompressionAlgorithms(), + other.getServer2ClientCompressionAlgorithms()), + other.getHostKeyAlgorithms().containsAll(KeyAlgorithms.SSH_RSA_SHA2_ALGORITHMS)); + } + + private List filterKnownHostKeyAlgorithms(List configuredKeyAlgorithms, List knownHostKeyAlgorithms) { + if (knownHostKeyAlgorithms != null && !knownHostKeyAlgorithms.isEmpty()) { + List preferredAlgorithms = new ArrayList(); + List otherAlgorithms = new ArrayList(); + + for (String configuredKeyAlgorithm : configuredKeyAlgorithms) { + if (knownHostKeyAlgorithms.contains(configuredKeyAlgorithm)) { + preferredAlgorithms.add(configuredKeyAlgorithm); + } else { + otherAlgorithms.add(configuredKeyAlgorithm); + } + } + + preferredAlgorithms.addAll(otherAlgorithms); + + return preferredAlgorithms; + } else { + return configuredKeyAlgorithms; + } + } private static String firstMatch(String ofWhat, List a, List b) diff --git a/src/main/java/net/schmizz/sshj/transport/verification/FingerprintVerifier.java b/src/main/java/net/schmizz/sshj/transport/verification/FingerprintVerifier.java index 58656bf5..d0a3cc58 100644 --- a/src/main/java/net/schmizz/sshj/transport/verification/FingerprintVerifier.java +++ b/src/main/java/net/schmizz/sshj/transport/verification/FingerprintVerifier.java @@ -20,6 +20,8 @@ import java.security.MessageDigest; import java.security.PublicKey; import java.util.Arrays; +import java.util.Collections; +import java.util.List; import java.util.regex.Pattern; import net.schmizz.sshj.common.Base64; @@ -74,6 +76,11 @@ public static HostKeyVerifier getInstance(String fingerprint) { public boolean verify(String h, int p, PublicKey k) { return SecurityUtils.getFingerprint(k).equals(md5); } + + @Override + public List findExistingAlgorithms(String hostname, int port) { + return Collections.emptyList(); + } }); } catch (SSHRuntimeException e) { throw e; @@ -120,8 +127,13 @@ public boolean verify(String hostname, int port, PublicKey key) { return Arrays.equals(fingerprintData, digestData); } + @Override + public List findExistingAlgorithms(String hostname, int port) { + return Collections.emptyList(); + } + @Override public String toString() { return "FingerprintVerifier{digestAlgorithm='" + digestAlgorithm + "'}"; } -} \ No newline at end of file +} diff --git a/src/main/java/net/schmizz/sshj/transport/verification/HostKeyVerifier.java b/src/main/java/net/schmizz/sshj/transport/verification/HostKeyVerifier.java index bcfc54d7..e53902f7 100644 --- a/src/main/java/net/schmizz/sshj/transport/verification/HostKeyVerifier.java +++ b/src/main/java/net/schmizz/sshj/transport/verification/HostKeyVerifier.java @@ -16,6 +16,7 @@ package net.schmizz.sshj.transport.verification; import java.security.PublicKey; +import java.util.List; /** Host key verification interface. */ public interface HostKeyVerifier { @@ -35,4 +36,12 @@ public interface HostKeyVerifier { */ boolean verify(String hostname, int port, PublicKey key); + /** + * It is necessary to connect with the type of algorithm that matches an existing know_host entry. + * This will allow a match when we later verify with the negotiated key {@code HostKeyVerifier.verify} + * @param hostname remote hostname + * @param port remote port + * @return existing key types or empty list if no keys known for hostname + */ + List findExistingAlgorithms(String hostname, int port); } diff --git a/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java b/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java index 8f38472a..7c271d62 100644 --- a/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java +++ b/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java @@ -90,6 +90,10 @@ private void readEntries(BufferedReader br) throws IOException { } } + private String adjustHostname(final String hostname, final int port) { + String lowerHN = hostname.toLowerCase(); + return (port != 22) ? "[" + lowerHN + "]:" + port : lowerHN; + } public File getFile() { return khFile; @@ -103,7 +107,7 @@ public boolean verify(final String hostname, final int port, final PublicKey key return false; } - final String adjustedHostname = (port != 22) ? "[" + hostname + "]:" + port : hostname; + final String adjustedHostname = adjustHostname(hostname, port); boolean foundApplicableHostEntry = false; for (KnownHostEntry e : entries) { @@ -127,6 +131,22 @@ public boolean verify(final String hostname, final int port, final PublicKey key return hostKeyUnverifiableAction(adjustedHostname, key); } + @Override + public List findExistingAlgorithms(String hostname, int port) { + final String adjustedHostname = adjustHostname(hostname, port); + List knownHostAlgorithms = new ArrayList(); + for (KnownHostEntry e : entries) { + try { + if (e.appliesTo(adjustedHostname)) { + knownHostAlgorithms.add(e.getType().toString()); + } + } catch (IOException ioe) { + } + } + + return knownHostAlgorithms; + } + protected boolean hostKeyUnverifiableAction(String hostname, PublicKey key) { return false; } diff --git a/src/main/java/net/schmizz/sshj/transport/verification/PromiscuousVerifier.java b/src/main/java/net/schmizz/sshj/transport/verification/PromiscuousVerifier.java index a95cc41f..c673fd7e 100644 --- a/src/main/java/net/schmizz/sshj/transport/verification/PromiscuousVerifier.java +++ b/src/main/java/net/schmizz/sshj/transport/verification/PromiscuousVerifier.java @@ -16,6 +16,8 @@ package net.schmizz.sshj.transport.verification; import java.security.PublicKey; +import java.util.Collections; +import java.util.List; public final class PromiscuousVerifier implements HostKeyVerifier { @@ -25,4 +27,9 @@ public boolean verify(String hostname, int port, PublicKey key) { return true; } + @Override + public List findExistingAlgorithms(String hostname, int port) { + return Collections.emptyList(); + } + }