From ec46a7a48983ff6f2bb758a161f2fc0aab51aee7 Mon Sep 17 00:00:00 2001 From: Jeroen van Erp Date: Fri, 29 Sep 2017 13:23:21 +0200 Subject: [PATCH] Fix decoding signature bytes (Fixes #355, #354) (#361) * Fix for signature verify in DSA * Cleaned up signature verification * Fixed import * Ignored erroneous pmd warnings * Updated JavaDoc --- .../sshj/signature/SignatureEdDSA.java | 67 ++-------- .../sshj/signature/AbstractSignature.java | 64 +++++----- .../schmizz/sshj/signature/SignatureDSA.java | 53 ++++---- .../sshj/signature/SignatureECDSA.java | 29 +---- .../schmizz/sshj/signature/SignatureRSA.java | 2 +- .../sshj/signature/SignatureDSASpec.groovy | 116 ++++++++++++++++++ 6 files changed, 193 insertions(+), 138 deletions(-) create mode 100644 src/test/groovy/net/schmizz/sshj/signature/SignatureDSASpec.groovy diff --git a/src/main/java/com/hierynomus/sshj/signature/SignatureEdDSA.java b/src/main/java/com/hierynomus/sshj/signature/SignatureEdDSA.java index 726e1b0f2..19a1da4ab 100644 --- a/src/main/java/com/hierynomus/sshj/signature/SignatureEdDSA.java +++ b/src/main/java/com/hierynomus/sshj/signature/SignatureEdDSA.java @@ -16,14 +16,16 @@ package com.hierynomus.sshj.signature; import net.i2p.crypto.eddsa.EdDSAEngine; -import net.schmizz.sshj.common.Buffer; import net.schmizz.sshj.common.KeyType; import net.schmizz.sshj.common.SSHRuntimeException; +import net.schmizz.sshj.signature.AbstractSignature; import net.schmizz.sshj.signature.Signature; -import java.security.*; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.security.SignatureException; -public class SignatureEdDSA implements Signature { +public class SignatureEdDSA extends AbstractSignature { public static class Factory implements net.schmizz.sshj.common.Factory.Named { @Override @@ -37,53 +39,14 @@ public Signature create() { } } - final EdDSAEngine engine; - - protected SignatureEdDSA() { - try { - engine = new EdDSAEngine(MessageDigest.getInstance("SHA-512")); - } catch (NoSuchAlgorithmException e) { - throw new SSHRuntimeException(e); - } - } - - @Override - public void initVerify(PublicKey pubkey) { - try { - engine.initVerify(pubkey); - } catch (InvalidKeyException e) { - throw new SSHRuntimeException(e); - } + SignatureEdDSA() { + super(getEngine()); } - @Override - public void initSign(PrivateKey prvkey) { + private static EdDSAEngine getEngine() { try { - engine.initSign(prvkey); - } catch (InvalidKeyException e) { - throw new SSHRuntimeException(e); - } - } - - @Override - public void update(byte[] H) { - update(H, 0, H.length); - } - - @Override - public void update(byte[] H, int off, int len) { - try { - engine.update(H, off, len); - } catch (SignatureException e) { - throw new SSHRuntimeException(e); - } - } - - @Override - public byte[] sign() { - try { - return engine.sign(); - } catch (SignatureException e) { + return new EdDSAEngine(MessageDigest.getInstance("SHA-512")); + } catch (NoSuchAlgorithmException e) { throw new SSHRuntimeException(e); } } @@ -96,17 +59,9 @@ public byte[] encode(byte[] signature) { @Override public boolean verify(byte[] sig) { try { - Buffer.PlainBuffer plainBuffer = new Buffer.PlainBuffer(sig); - String algo = plainBuffer.readString(); - if (!"ssh-ed25519".equals(algo)) { - throw new SSHRuntimeException("Expected 'ssh-ed25519' key algorithm, but was: " + algo); - } - byte[] bytes = plainBuffer.readBytes(); - return engine.verify(bytes); + return signature.verify(extractSig(sig, "ssh-ed25519")); } catch (SignatureException e) { throw new SSHRuntimeException(e); - } catch (Buffer.BufferException e) { - throw new SSHRuntimeException(e); } } } diff --git a/src/main/java/net/schmizz/sshj/signature/AbstractSignature.java b/src/main/java/net/schmizz/sshj/signature/AbstractSignature.java index 2e47c36d3..3eb9ae963 100644 --- a/src/main/java/net/schmizz/sshj/signature/AbstractSignature.java +++ b/src/main/java/net/schmizz/sshj/signature/AbstractSignature.java @@ -15,31 +15,39 @@ */ package net.schmizz.sshj.signature; +import net.schmizz.sshj.common.Buffer; import net.schmizz.sshj.common.SSHRuntimeException; import net.schmizz.sshj.common.SecurityUtils; -import java.security.GeneralSecurityException; -import java.security.PrivateKey; -import java.security.PublicKey; -import java.security.SignatureException; +import java.security.*; -/** An abstract class for {@link Signature} that implements common functionality. */ +/** + * An abstract class for {@link Signature} that implements common functionality. + */ public abstract class AbstractSignature implements Signature { - protected final String algorithm; - protected java.security.Signature signature; + @SuppressWarnings("PMD.UnnecessaryFullyQualifiedName") + protected final java.security.Signature signature; protected AbstractSignature(String algorithm) { - this.algorithm = algorithm; + try { + this.signature = SecurityUtils.getSignature(algorithm); + } catch (GeneralSecurityException e) { + throw new SSHRuntimeException(e); + } + } + + protected AbstractSignature(@SuppressWarnings("PMD.UnnecessaryFullyQualifiedName") + java.security.Signature signatureEngine) { + this.signature = signatureEngine; } @Override public void initVerify(PublicKey publicKey) { try { - signature = SecurityUtils.getSignature(algorithm); signature.initVerify(publicKey); - } catch (GeneralSecurityException e) { + } catch (InvalidKeyException e) { throw new SSHRuntimeException(e); } } @@ -47,9 +55,8 @@ public void initVerify(PublicKey publicKey) { @Override public void initSign(PrivateKey privateKey) { try { - signature = SecurityUtils.getSignature(algorithm); signature.initSign(privateKey); - } catch (GeneralSecurityException e) { + } catch (InvalidKeyException e) { throw new SSHRuntimeException(e); } } @@ -77,23 +84,24 @@ public byte[] sign() { } } - protected byte[] extractSig(byte[] sig) { - if (sig[0] == 0 && sig[1] == 0 && sig[2] == 0) { - int i = 0; - int j = sig[i++] << 24 & 0xff000000 - | sig[i++] << 16 & 0x00ff0000 - | sig[i++] << 8 & 0x0000ff00 - | sig[i++] & 0x000000ff; - i += j; - j = sig[i++] << 24 & 0xff000000 - | sig[i++] << 16 & 0x00ff0000 - | sig[i++] << 8 & 0x0000ff00 - | sig[i++] & 0x000000ff; - byte[] newSig = new byte[j]; - System.arraycopy(sig, i, newSig, 0, j); - return newSig; + /** + * Check whether the signature is generated using the expected algorithm, and if so, return the signature blob + * + * @param sig The full signature + * @param expectedKeyAlgorithm The expected key algorithm + * @return The blob part of the signature + */ + protected byte[] extractSig(byte[] sig, String expectedKeyAlgorithm) { + Buffer.PlainBuffer buffer = new Buffer.PlainBuffer(sig); + try { + String algo = buffer.readString(); + if (!expectedKeyAlgorithm.equals(algo)) { + throw new SSHRuntimeException("Expected '" + expectedKeyAlgorithm + "' key algorithm, but got: " + algo); + } + return buffer.readBytes(); + } catch (Buffer.BufferException e) { + throw new SSHRuntimeException(e); } - return sig; } } diff --git a/src/main/java/net/schmizz/sshj/signature/SignatureDSA.java b/src/main/java/net/schmizz/sshj/signature/SignatureDSA.java index 12305efc3..9b11b6ca7 100644 --- a/src/main/java/net/schmizz/sshj/signature/SignatureDSA.java +++ b/src/main/java/net/schmizz/sshj/signature/SignatureDSA.java @@ -17,14 +17,23 @@ import net.schmizz.sshj.common.KeyType; import net.schmizz.sshj.common.SSHRuntimeException; +import org.bouncycastle.asn1.*; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.math.BigInteger; import java.security.SignatureException; +import java.util.Arrays; -/** DSA {@link Signature} */ +/** + * DSA {@link Signature} + */ public class SignatureDSA extends AbstractSignature { - /** A named factory for DSA signature */ + /** + * A named factory for DSA signature + */ public static class Factory implements net.schmizz.sshj.common.Factory.Named { @@ -73,33 +82,25 @@ public byte[] encode(byte[] sig) { } @Override - public boolean verify(byte[] sig) { - sig = extractSig(sig); - - // ASN.1 - int frst = (sig[0] & 0x80) != 0 ? 1 : 0; - int scnd = (sig[20] & 0x80) != 0 ? 1 : 0; - - int length = sig.length + 6 + frst + scnd; - byte[] tmp = new byte[length]; - tmp[0] = (byte) 0x30; - tmp[1] = (byte) 0x2c; - tmp[1] += frst; - tmp[1] += scnd; - tmp[2] = (byte) 0x02; - tmp[3] = (byte) 0x14; - tmp[3] += frst; - System.arraycopy(sig, 0, tmp, 4 + frst, 20); - tmp[4 + tmp[3]] = (byte) 0x02; - tmp[5 + tmp[3]] = (byte) 0x14; - tmp[5 + tmp[3]] += scnd; - System.arraycopy(sig, 20, tmp, 6 + tmp[3] + scnd, 20); - sig = tmp; - + public boolean verify(byte[] incomingSig) { + byte[] extractSig = extractSig(incomingSig, "ssh-dss"); try { - return signature.verify(sig); + // ASN.1 + ByteArrayOutputStream os = new ByteArrayOutputStream(); + ASN1OutputStream asn1OutputStream = new ASN1OutputStream(os); + ASN1EncodableVector vector = new ASN1EncodableVector(); + BigInteger bigInteger = new BigInteger(1, Arrays.copyOfRange(extractSig, 0, 20)); + vector.add(new ASN1Integer(bigInteger)); + BigInteger bigInteger2 = new BigInteger(1, Arrays.copyOfRange(extractSig, 20, 40)); + vector.add(new ASN1Integer(bigInteger2)); + asn1OutputStream.writeObject(new DERSequence(vector)); + asn1OutputStream.close(); + byte[] finalSig = os.toByteArray(); + return signature.verify(finalSig); } catch (SignatureException e) { throw new SSHRuntimeException(e); + } catch (IOException e) { + throw new SSHRuntimeException(e); } } diff --git a/src/main/java/net/schmizz/sshj/signature/SignatureECDSA.java b/src/main/java/net/schmizz/sshj/signature/SignatureECDSA.java index c2e37faa3..1a857a340 100644 --- a/src/main/java/net/schmizz/sshj/signature/SignatureECDSA.java +++ b/src/main/java/net/schmizz/sshj/signature/SignatureECDSA.java @@ -110,15 +110,7 @@ public boolean verify(byte[] sig) { byte[] r; byte[] s; try { - Buffer sigbuf = new Buffer.PlainBuffer(sig); - final String algo = new String(sigbuf.readBytes()); - if (!keyTypeName.equals(algo)) { - throw new SSHRuntimeException(String.format("Signature :: " + keyTypeName + " expected, got %s", algo)); - } - final int rsLen = sigbuf.readUInt32AsInt(); - if (sigbuf.available() != rsLen) { - throw new SSHRuntimeException("Invalid key length"); - } + Buffer sigbuf = new Buffer.PlainBuffer(extractSig(sig, keyTypeName)); r = sigbuf.readBytes(); s = sigbuf.readBytes(); } catch (Exception e) { @@ -135,28 +127,11 @@ public boolean verify(byte[] sig) { } private byte[] asnEncode(byte[] r, byte[] s) throws IOException { - int rLen = r.length; - int sLen = s.length; - - /* - * We can't have the high bit set, so add an extra zero at the beginning - * if so. - */ - if ((r[0] & 0x80) != 0) { - rLen++; - } - if ((s[0] & 0x80) != 0) { - sLen++; - } - - /* Calculate total output length */ - int length = 6 + rLen + sLen; - ASN1EncodableVector vector = new ASN1EncodableVector(); vector.add(new ASN1Integer(r)); vector.add(new ASN1Integer(s)); - ByteArrayOutputStream baos = new ByteArrayOutputStream(length); + ByteArrayOutputStream baos = new ByteArrayOutputStream(); ASN1OutputStream asnOS = new ASN1OutputStream(baos); asnOS.writeObject(new DERSequence(vector)); diff --git a/src/main/java/net/schmizz/sshj/signature/SignatureRSA.java b/src/main/java/net/schmizz/sshj/signature/SignatureRSA.java index 6ed5f1bbd..16f33dbc9 100644 --- a/src/main/java/net/schmizz/sshj/signature/SignatureRSA.java +++ b/src/main/java/net/schmizz/sshj/signature/SignatureRSA.java @@ -51,7 +51,7 @@ public byte[] encode(byte[] signature) { @Override public boolean verify(byte[] sig) { - sig = extractSig(sig); + sig = extractSig(sig, "ssh-rsa"); try { return signature.verify(sig); } catch (SignatureException e) { diff --git a/src/test/groovy/net/schmizz/sshj/signature/SignatureDSASpec.groovy b/src/test/groovy/net/schmizz/sshj/signature/SignatureDSASpec.groovy new file mode 100644 index 000000000..8913b7f5c --- /dev/null +++ b/src/test/groovy/net/schmizz/sshj/signature/SignatureDSASpec.groovy @@ -0,0 +1,116 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* +* Copyright (C)2009 - SSHJ Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +package net.schmizz.sshj.signature + +import spock.lang.Unroll; + +import java.math.BigInteger; +import java.security.KeyFactory; +import java.security.NoSuchAlgorithmException; +import java.security.PublicKey; +import java.security.spec.DSAPublicKeySpec; +import java.util.ArrayList; +import java.util.List; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import spock.lang.Specification + +class SignatureDSASpec extends Specification { + + def keyFactory = KeyFactory.getInstance("DSA") + + private PublicKey createPublicKey(final byte[] y, final byte[] p, final byte[] q, final byte[] g) throws Exception { + final BigInteger publicKey = new BigInteger(y); + final BigInteger prime = new BigInteger(p); + final BigInteger subPrime = new BigInteger(q); + final BigInteger base = new BigInteger(g); + final DSAPublicKeySpec dsaPubKeySpec = new DSAPublicKeySpec(publicKey, prime, subPrime, base); + return keyFactory.generatePublic(dsaPubKeySpec); + } + + + @Unroll + def "should verify signature"() { + given: + def signatureDSA = new SignatureDSA() + def publicKey = createPublicKey(y, p, q, g) + signatureDSA.initVerify(publicKey) + + when: + signatureDSA.update(H) + + then: + signatureDSA.verify(H_sig) + + where: + y << [[103, 23, -102, -4, -110, -90, 66, -52, -14, 125, -16, -76, -110, 33, -111, -113, -46, 27, -118, -73, 0, -19, -48, 43, -102, 56, -49, -84, 118, -10, 76, 84, -5, 84, 55, 72, -115, -34, 95, 80, 32, -120, 57, 101, -64, 111, -37, -26, 96, 55, -98, -24, -99, -81, 60, 22, 5, -55, 119, -95, -28, 114, -40, 13, 97, 65, 22, 33, 117, -59, 22, 81, -56, 98, -112, 103, -62, 90, -12, 81, 61, -67, 104, -24, 67, -18, -60, 78, -127, 44, 13, 11, -117, -118, -69, 89, -25, 26, 103, 72, -83, 114, -40, -124, -10, -31, -34, -49, -54, -15, 92, 79, -40, 14, -12, 58, -112, -30, 11, 48, 26, 121, 105, -68, 92, -93, 99, -78] as byte[], + [0, -92, 59, 5, 72, 124, 101, 124, -18, 114, 7, 100, 98, -61, 73, -104, 120, -98, 54, 118, 17, -62, 91, -110, 29, 98, 50, -101, -41, 99, -116, 101, 107, -123, 124, -97, 62, 119, 88, -109, -110, -1, 109, 119, -51, 69, -98, -105, 2, -69, -121, -82, -118, 23, -6, 96, -61, -65, 102, -58, -74, 32, -104, 116, -6, -35, -83, -10, -88, -68, 106, -112, 72, -2, 35, 38, 15, -11, -22, 30, -114, -46, -47, -18, -17, -71, 24, -25, 28, 13, 29, -40, 101, 18, 81, 45, -120, -67, -53, -41, 11, 50, -89, -33, 50, 54, -14, -91, -35, 12, -42, 13, -84, -19, 100, -3, -85, -18, 74, 99, -49, 64, -49, 51, -83, -82, -127, 116, 64] as byte[]] + p << [[0, -3, 127, 83, -127, 29, 117, 18, 41, 82, -33, 74, -100, 46, -20, -28, -25, -10, 17, -73, 82, 60, -17, 68, + 0, -61, 30, 63, -128, -74, 81, 38, 105, 69, 93, 64, 34, 81, -5, 89, 61, -115, 88, -6, -65, -59, -11, -70, + 48, -10, -53, -101, 85, 108, -41, -127, 59, -128, 29, 52, 111, -14, 102, 96, -73, 107, -103, 80, -91, -92, + -97, -97, -24, 4, 123, 16, 34, -62, 79, -69, -87, -41, -2, -73, -58, 27, -8, 59, 87, -25, -58, -88, -90, 21, + 15, 4, -5, -125, -10, -45, -59, 30, -61, 2, 53, 84, 19, 90, 22, -111, 50, -10, 117, -13, -82, 43, 97, -41, + 42, -17, -14, 34, 3, 25, -99, -47, 72, 1, -57] as byte[], + [0, -3, 127, 83, -127, 29, 117, 18, 41, 82, -33, 74, -100, 46, -20, -28, -25, -10, 17, -73, 82, 60, -17, 68, + 0, -61, 30, 63, -128, -74, 81, 38, 105, 69, 93, 64, 34, 81, -5, 89, 61, -115, 88, -6, -65, -59, -11, -70, + 48, -10, -53, -101, 85, 108, -41, -127, 59, -128, 29, 52, 111, -14, 102, 96, -73, 107, -103, 80, -91, -92, + -97, -97, -24, 4, 123, 16, 34, -62, 79, -69, -87, -41, -2, -73, -58, 27, -8, 59, 87, -25, -58, -88, -90, 21, + 15, 4, -5, -125, -10, -45, -59, 30, -61, 2, 53, 84, 19, 90, 22, -111, 50, -10, 117, -13, -82, 43, 97, -41, + 42, -17, -14, 34, 3, 25, -99, -47, 72, 1, -57] as byte[]] + q << [[0, -105, 96, 80, -113, 21, 35, 11, -52, -78, -110, -71, -126, -94, -21, -124, 11, -16, 88, 28, -11] as byte[], + [0, -105, 96, 80, -113, 21, 35, 11, -52, -78, -110, -71, -126, -94, -21, -124, 11, -16, 88, 28, -11] as byte[]] + g << [[0, -9, -31, -96, -123, -42, -101, 61, -34, -53, -68, -85, 92, 54, -72, 87, -71, 121, -108, -81, -69, -6, 58, + -22, -126, -7, 87, 76, 11, 61, 7, -126, 103, 81, 89, 87, -114, -70, -44, 89, 79, -26, 113, 7, 16, -127, + -128, -76, 73, 22, 113, 35, -24, 76, 40, 22, 19, -73, -49, 9, 50, -116, -56, -90, -31, 60, 22, 122, -117, + 84, 124, -115, 40, -32, -93, -82, 30, 43, -77, -90, 117, -111, 110, -93, 127, 11, -6, 33, 53, 98, -15, -5, + 98, 122, 1, 36, 59, -52, -92, -15, -66, -88, 81, -112, -119, -88, -125, -33, -31, 90, -27, -97, 6, -110, + -117, 102, 94, -128, 123, 85, 37, 100, 1, 76, 59, -2, -49, 73, 42] as byte[], + [0, -9, -31, -96, -123, -42, -101, 61, -34, -53, -68, -85, 92, 54, -72, 87, -71, 121, -108, -81, -69, -6, 58, + -22, -126, -7, 87, 76, 11, 61, 7, -126, 103, 81, 89, 87, -114, -70, -44, 89, 79, -26, 113, 7, 16, -127, + -128, -76, 73, 22, 113, 35, -24, 76, 40, 22, 19, -73, -49, 9, 50, -116, -56, -90, -31, 60, 22, 122, -117, + 84, 124, -115, 40, -32, -93, -82, 30, 43, -77, -90, 117, -111, 110, -93, 127, 11, -6, 33, 53, 98, -15, -5, + 98, 122, 1, 36, 59, -52, -92, -15, -66, -88, 81, -112, -119, -88, -125, -33, -31, 90, -27, -97, 6, -110, + -117, 102, 94, -128, 123, 85, 37, 100, 1, 76, 59, -2, -49, 73, 42] as byte[]] + H << [[-13, 20, 103, 73, 115, -68, 113, 74, -25, 12, -90, 19, 56, 73, -7, -49, -118, 107, -69, -39, -6, 82, -123, + 54, -10, -43, 16, -117, -59, 36, -49, 27] as byte[], + [-4, 111, -103, 111, 72, -106, 105, -19, 81, -123, 84, -13, -40, -53, -3, -97, -8, 43, -22, -2, -23, -15, 28, + 116, -63, 96, -79, -127, -84, 63, -6, -94] as byte[]] + H_sig << [[0, 0, 0, 7, 115, 115, 104, 45, 100, 115, 115, 0, 0, 0, 40, -113, -52, 88, -117, 80, -105, -92, -124, -49, + 56, -35, 90, -9, -128, 31, -33, -18, 13, -5, 7, 108, -2, 92, 108, 85, 58, 39, 99, 122, -118, 125, -121, 21, + -37, 2, 55, 109, -23, -125, 4] as byte[], + [0, 0, 0, 7, 115, 115, 104, 45, 100, 115, 115, 0, 0, 0, 40, 0, 79, 84, 118, -50, 11, -117, -112, 52, -25, + -78, -50, -20, 6, -69, -26, 7, 90, -34, -124, 80, 76, -32, -23, -8, 43, 38, -48, -89, -17, -60, -1, -78, + 112, -88, 14, -39, -78, -98, -80] as byte[]] + } + +}