Skip to content

Commit

Permalink
Prefer known algorithm for known host (#721)
Browse files Browse the repository at this point in the history
* Prefer known algorithm for known host

(#642, #635... 10? issues)

Try to find the Algorithm that was used when a known_host
entry was created and make that the first choice for the
current connection attempt.

If the current connection algorithm matches the
algorithm used when the known_host entry was created
we can get a fair verification.

* Add support for multiple matching hostkeys, in configuration order

Co-authored-by: Bernie Day <[email protected]>
Co-authored-by: Jeroen van Erp <[email protected]>
  • Loading branch information
3 people authored Sep 23, 2021
1 parent 753e3a5 commit 14bf93e
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 31 deletions.
13 changes: 12 additions & 1 deletion src/main/java/net/schmizz/sshj/transport/KeyExchanger.java
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,22 @@ private static void ensureReceivedMatchesExpected(Message got, Message expected)
private void sendKexInit()
throws TransportException {
log.debug("Sending SSH_MSG_KEXINIT");
clientProposal = new Proposal(transport.getConfig());
List<String> knownHostAlgs = findKnownHostAlgs(transport.getRemoteHost(), transport.getRemotePort());
clientProposal = new Proposal(transport.getConfig(), knownHostAlgs);
transport.write(clientProposal.getPacket());
kexInitSent.set();
}

private List<String> findKnownHostAlgs(String hostname, int port) {
for (HostKeyVerifier hkv : hostVerifiers) {
List<String> keyTypes = hkv.findExistingAlgorithms(hostname, port);
if (keyTypes != null && !keyTypes.isEmpty()) {
return keyTypes;
}
}
return Collections.emptyList();
}

private void sendNewKeys()
throws TransportException {
log.debug("Sending SSH_MSG_NEWKEYS");
Expand Down
67 changes: 39 additions & 28 deletions src/main/java/net/schmizz/sshj/transport/Proposal.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ class Proposal {
private final List<String> s2cComp;
private final SSHPacket packet;

public Proposal(Config config) {
public Proposal(Config config, List<String> knownHostAlgs) {
kex = Factory.Named.Util.getNames(config.getKeyExchangeFactories());
sig = Factory.Named.Util.getNames(config.getKeyAlgorithms());
sig = filterKnownHostKeyAlgorithms(Factory.Named.Util.getNames(config.getKeyAlgorithms()), knownHostAlgs);
c2sCipher = s2cCipher = Factory.Named.Util.getNames(config.getCipherFactories());
c2sMAC = s2cMAC = Factory.Named.Util.getNames(config.getMACFactories());
c2sComp = s2cComp = Factory.Named.Util.getNames(config.getCompressionFactories());
Expand Down Expand Up @@ -127,32 +127,43 @@ public SSHPacket getPacket() {
public NegotiatedAlgorithms negotiate(Proposal other)
throws TransportException {
return new NegotiatedAlgorithms(
firstMatch("KeyExchangeAlgorithms",
this.getKeyExchangeAlgorithms(),
other.getKeyExchangeAlgorithms()),
firstMatch("HostKeyAlgorithms",
this.getHostKeyAlgorithms(),
other.getHostKeyAlgorithms()),
firstMatch("Client2ServerCipherAlgorithms",
this.getClient2ServerCipherAlgorithms(),
other.getClient2ServerCipherAlgorithms()),
firstMatch("Server2ClientCipherAlgorithms",
this.getServer2ClientCipherAlgorithms(),
other.getServer2ClientCipherAlgorithms()),
firstMatch("Client2ServerMACAlgorithms",
this.getClient2ServerMACAlgorithms(),
other.getClient2ServerMACAlgorithms()),
firstMatch("Server2ClientMACAlgorithms",
this.getServer2ClientMACAlgorithms(),
other.getServer2ClientMACAlgorithms()),
firstMatch("Client2ServerCompressionAlgorithms",
this.getClient2ServerCompressionAlgorithms(),
other.getClient2ServerCompressionAlgorithms()),
firstMatch("Server2ClientCompressionAlgorithms",
this.getServer2ClientCompressionAlgorithms(),
other.getServer2ClientCompressionAlgorithms()),
other.getHostKeyAlgorithms().containsAll(KeyAlgorithms.SSH_RSA_SHA2_ALGORITHMS)
);
firstMatch("KeyExchangeAlgorithms", this.getKeyExchangeAlgorithms(), other.getKeyExchangeAlgorithms()),
firstMatch("HostKeyAlgorithms", this.getHostKeyAlgorithms(), other.getHostKeyAlgorithms()),
firstMatch("Client2ServerCipherAlgorithms", this.getClient2ServerCipherAlgorithms(),
other.getClient2ServerCipherAlgorithms()),
firstMatch("Server2ClientCipherAlgorithms", this.getServer2ClientCipherAlgorithms(),
other.getServer2ClientCipherAlgorithms()),
firstMatch("Client2ServerMACAlgorithms", this.getClient2ServerMACAlgorithms(),
other.getClient2ServerMACAlgorithms()),
firstMatch("Server2ClientMACAlgorithms", this.getServer2ClientMACAlgorithms(),
other.getServer2ClientMACAlgorithms()),
firstMatch("Client2ServerCompressionAlgorithms", this.getClient2ServerCompressionAlgorithms(),
other.getClient2ServerCompressionAlgorithms()),
firstMatch("Server2ClientCompressionAlgorithms", this.getServer2ClientCompressionAlgorithms(),
other.getServer2ClientCompressionAlgorithms()),
other.getHostKeyAlgorithms().containsAll(KeyAlgorithms.SSH_RSA_SHA2_ALGORITHMS));
}

private List<String> filterKnownHostKeyAlgorithms(List<String> configuredKeyAlgorithms, List<String> knownHostKeyAlgorithms) {
if (knownHostKeyAlgorithms != null && !knownHostKeyAlgorithms.isEmpty()) {
List<String> preferredAlgorithms = new ArrayList<String>();
List<String> otherAlgorithms = new ArrayList<String>();

for (String configuredKeyAlgorithm : configuredKeyAlgorithms) {
if (knownHostKeyAlgorithms.contains(configuredKeyAlgorithm)) {
preferredAlgorithms.add(configuredKeyAlgorithm);
} else {
otherAlgorithms.add(configuredKeyAlgorithm);
}
}

preferredAlgorithms.addAll(otherAlgorithms);

return preferredAlgorithms;
} else {
return configuredKeyAlgorithms;
}

}

private static String firstMatch(String ofWhat, List<String> a, List<String> b)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import java.security.MessageDigest;
import java.security.PublicKey;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.regex.Pattern;

import net.schmizz.sshj.common.Base64;
Expand Down Expand Up @@ -74,6 +76,11 @@ public static HostKeyVerifier getInstance(String fingerprint) {
public boolean verify(String h, int p, PublicKey k) {
return SecurityUtils.getFingerprint(k).equals(md5);
}

@Override
public List<String> findExistingAlgorithms(String hostname, int port) {
return Collections.emptyList();
}
});
} catch (SSHRuntimeException e) {
throw e;
Expand Down Expand Up @@ -120,8 +127,13 @@ public boolean verify(String hostname, int port, PublicKey key) {
return Arrays.equals(fingerprintData, digestData);
}

@Override
public List<String> findExistingAlgorithms(String hostname, int port) {
return Collections.emptyList();
}

@Override
public String toString() {
return "FingerprintVerifier{digestAlgorithm='" + digestAlgorithm + "'}";
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package net.schmizz.sshj.transport.verification;

import java.security.PublicKey;
import java.util.List;

/** Host key verification interface. */
public interface HostKeyVerifier {
Expand All @@ -35,4 +36,12 @@ public interface HostKeyVerifier {
*/
boolean verify(String hostname, int port, PublicKey key);

/**
* It is necessary to connect with the type of algorithm that matches an existing know_host entry.
* This will allow a match when we later verify with the negotiated key {@code HostKeyVerifier.verify}
* @param hostname remote hostname
* @param port remote port
* @return existing key types or empty list if no keys known for hostname
*/
List<String> findExistingAlgorithms(String hostname, int port);
}
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ private void readEntries(BufferedReader br) throws IOException {
}
}

private String adjustHostname(final String hostname, final int port) {
String lowerHN = hostname.toLowerCase();
return (port != 22) ? "[" + lowerHN + "]:" + port : lowerHN;
}

public File getFile() {
return khFile;
Expand All @@ -103,7 +107,7 @@ public boolean verify(final String hostname, final int port, final PublicKey key
return false;
}

final String adjustedHostname = (port != 22) ? "[" + hostname + "]:" + port : hostname;
final String adjustedHostname = adjustHostname(hostname, port);

boolean foundApplicableHostEntry = false;
for (KnownHostEntry e : entries) {
Expand All @@ -127,6 +131,22 @@ public boolean verify(final String hostname, final int port, final PublicKey key
return hostKeyUnverifiableAction(adjustedHostname, key);
}

@Override
public List<String> findExistingAlgorithms(String hostname, int port) {
final String adjustedHostname = adjustHostname(hostname, port);
List<String> knownHostAlgorithms = new ArrayList<String>();
for (KnownHostEntry e : entries) {
try {
if (e.appliesTo(adjustedHostname)) {
knownHostAlgorithms.add(e.getType().toString());
}
} catch (IOException ioe) {
}
}

return knownHostAlgorithms;
}

protected boolean hostKeyUnverifiableAction(String hostname, PublicKey key) {
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
package net.schmizz.sshj.transport.verification;

import java.security.PublicKey;
import java.util.Collections;
import java.util.List;

public final class PromiscuousVerifier
implements HostKeyVerifier {
Expand All @@ -25,4 +27,9 @@ public boolean verify(String hostname, int port, PublicKey key) {
return true;
}

@Override
public List<String> findExistingAlgorithms(String hostname, int port) {
return Collections.emptyList();
}

}

0 comments on commit 14bf93e

Please sign in to comment.