Skip to content

Commit

Permalink
Use OkHttp's certificate creation code
Browse files Browse the repository at this point in the history
We don't implement the full feature set that Bouncycastle has, but
we also don't need it.

In follow up changes I intend to remove the Bouncycastle dependency
for everything but some test cases.
  • Loading branch information
squarejesse committed Jun 28, 2020
1 parent c3d453c commit 39f46f7
Show file tree
Hide file tree
Showing 5 changed files with 246 additions and 78 deletions.
198 changes: 125 additions & 73 deletions okhttp-tls/src/main/kotlin/okhttp3/tls/HeldCertificate.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package okhttp3.tls

import java.math.BigInteger
import java.net.InetAddress
import java.security.GeneralSecurityException
import java.security.KeyFactory
import java.security.KeyPair
Expand All @@ -24,27 +25,32 @@ import java.security.PrivateKey
import java.security.PublicKey
import java.security.SecureRandom
import java.security.Security
import java.security.Signature
import java.security.cert.X509Certificate
import java.security.interfaces.ECPublicKey
import java.security.interfaces.RSAPrivateKey
import java.security.interfaces.RSAPublicKey
import java.security.spec.PKCS8EncodedKeySpec
import java.util.Date
import java.util.UUID
import java.util.concurrent.TimeUnit
import javax.security.auth.x500.X500Principal
import okhttp3.internal.canParseAsIpAddress
import okhttp3.tls.internal.der.AlgorithmIdentifier
import okhttp3.tls.internal.der.AttributeTypeAndValue
import okhttp3.tls.internal.der.BasicConstraints
import okhttp3.tls.internal.der.BitString
import okhttp3.tls.internal.der.Certificate
import okhttp3.tls.internal.der.CertificateAdapters
import okhttp3.tls.internal.der.CertificateAdapters.generalNameDnsName
import okhttp3.tls.internal.der.CertificateAdapters.generalNameIpAddress
import okhttp3.tls.internal.der.Extension
import okhttp3.tls.internal.der.ObjectIdentifiers
import okhttp3.tls.internal.der.TbsCertificate
import okhttp3.tls.internal.der.Validity
import okio.ByteString
import okio.ByteString.Companion.decodeBase64
import okio.ByteString.Companion.toByteString
import org.bouncycastle.asn1.ASN1Encodable
import org.bouncycastle.asn1.DERSequence
import org.bouncycastle.asn1.pkcs.PrivateKeyInfo
import org.bouncycastle.asn1.x509.BasicConstraints
import org.bouncycastle.asn1.x509.GeneralName
import org.bouncycastle.asn1.x509.X509Extensions
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.bouncycastle.x509.X509V3CertificateGenerator

/**
* A certificate and its private key. These are some properties of certificates that are used with
Expand Down Expand Up @@ -315,88 +321,134 @@ class HeldCertificate(
}

fun build(): HeldCertificate {
// Subject, public & private keys for this certificate.
val heldKeyPair = keyPair ?: generateKeyPair()
val subject = buildSubject()

// Subject, public & private keys for this certificate's signer. It may be self signed!
val signedByKeyPair: KeyPair
val signedByPrincipal: X500Principal
// Subject keys & identity.
val subjectKeyPair = keyPair ?: generateKeyPair()
val subjectPublicKeyInfo = CertificateAdapters.subjectPublicKeyInfo.fromDer(
subjectKeyPair.public.encoded.toByteString()
)
val subject: List<List<AttributeTypeAndValue>> = subject()

// Issuer/signer keys & identity. May be the subject if it is self-signed.
val issuerKeyPair: KeyPair
val issuer: List<List<AttributeTypeAndValue>>
if (signedBy != null) {
signedByKeyPair = signedBy!!.keyPair
signedByPrincipal = signedBy!!.certificate.subjectX500Principal
} else {
signedByKeyPair = heldKeyPair
signedByPrincipal = subject
}

// Generate & sign the certificate.
val notBefore = if (this.notBefore != -1L) {
this.notBefore
} else {
System.currentTimeMillis()
}
val notAfter = if (this.notAfter != -1L) {
this.notAfter
issuerKeyPair = signedBy!!.keyPair
issuer = CertificateAdapters.rdnSequence.fromDer(
signedBy!!.certificate.subjectX500Principal.encoded.toByteString()
)
} else {
notBefore + DEFAULT_DURATION_MILLIS
issuerKeyPair = subjectKeyPair
issuer = subject
}
val serialNumber = if (this.serialNumber != null) {
this.serialNumber
} else {
BigInteger.ONE
val signatureAlgorithm = signatureAlgorithm(issuerKeyPair)

// Subset of certificate data that's covered by the signature.
val tbsCertificate = TbsCertificate(
version = 2L, // v3.
serialNumber = serialNumber ?: BigInteger.ONE,
signature = signatureAlgorithm,
issuer = issuer,
validity = validity(),
subject = subject,
subjectPublicKeyInfo = subjectPublicKeyInfo,
issuerUniqueID = null,
subjectUniqueID = null,
extensions = extensions()
)

// Signature.
val signature = Signature.getInstance(tbsCertificate.signatureAlgorithmName).run {
initSign(issuerKeyPair.private)
update(CertificateAdapters.tbsCertificate.toDer(tbsCertificate).toByteArray())
sign().toByteString()
}
val signatureAlgorithm = if (signedByKeyPair.private is RSAPrivateKey) {
"SHA256WithRSA"
} else {
"SHA256withECDSA"

// Complete signed certificate.
val certificate = Certificate(
tbsCertificate = tbsCertificate,
signatureAlgorithm = signatureAlgorithm,
signatureValue = BitString(
byteString = signature,
unusedBitsCount = 0
)
)

return HeldCertificate(subjectKeyPair, certificate.toX509Certificate())
}

private fun subject(): List<List<AttributeTypeAndValue>> {
val result = mutableListOf<List<AttributeTypeAndValue>>()

if (ou != null) {
result += listOf(AttributeTypeAndValue(
type = ObjectIdentifiers.organizationalUnitName,
value = ou
))
}
val generator = X509V3CertificateGenerator()
generator.setSerialNumber(serialNumber)
generator.setIssuerDN(signedByPrincipal)
generator.setNotBefore(Date(notBefore))
generator.setNotAfter(Date(notAfter))
generator.setSubjectDN(subject)
generator.setPublicKey(heldKeyPair.public)
generator.setSignatureAlgorithm(signatureAlgorithm)

result += listOf(AttributeTypeAndValue(
type = ObjectIdentifiers.commonName,
value = cn ?: UUID.randomUUID().toString()
))

return result
}

private fun validity(): Validity {
val notBefore = if (notBefore != -1L) notBefore else System.currentTimeMillis()
val notAfter = if (notAfter != -1L) notAfter else notBefore + DEFAULT_DURATION_MILLIS
return Validity(
notBefore = notBefore,
notAfter = notAfter
)
}

private fun extensions(): MutableList<Extension> {
val result = mutableListOf<Extension>()

if (maxIntermediateCas != -1) {
generator.addExtension(X509Extensions.BasicConstraints, true,
BasicConstraints(maxIntermediateCas))
result += Extension(
extnID = ObjectIdentifiers.basicConstraints,
critical = true,
extnValue = BasicConstraints(
ca = true,
pathLenConstraint = 3
)
)
}

if (altNames.isNotEmpty()) {
val encodableAltNames = arrayOfNulls<ASN1Encodable>(altNames.size)
for (i in 0 until altNames.size) {
val altName = altNames[i]
val tag = when {
altName.canParseAsIpAddress() -> GeneralName.iPAddress
else -> GeneralName.dNSName
val extensionValue = altNames.map {
when {
it.canParseAsIpAddress() -> {
generalNameIpAddress to InetAddress.getByName(it).address.toByteString()
}
else -> {
generalNameDnsName to it
}
}
encodableAltNames[i] = GeneralName(tag, altName)
}
generator.addExtension(X509Extensions.SubjectAlternativeName, true,
DERSequence(encodableAltNames))
result += Extension(
extnID = ObjectIdentifiers.subjectAlternativeName,
critical = true,
extnValue = extensionValue
)
}

val certificate = generator.generate(signedByKeyPair.private)
return HeldCertificate(heldKeyPair, certificate)
return result
}

private fun buildSubject(): X500Principal {
val name = buildString {
append("CN=")
if (cn != null) {
append(cn)
} else {
append(UUID.randomUUID())
}
if (ou != null) {
append(", OU=")
append(ou)
}
private fun signatureAlgorithm(signedByKeyPair: KeyPair): AlgorithmIdentifier {
return when (signedByKeyPair.private) {
is RSAPrivateKey -> AlgorithmIdentifier(
algorithm = ObjectIdentifiers.sha256WithRSAEncryption,
parameters = null
)
else -> AlgorithmIdentifier(
algorithm = ObjectIdentifiers.sha256withEcdsa,
parameters = ByteString.EMPTY
)
}
return X500Principal(name)
}

private fun generateKeyPair(): KeyPair {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ internal object ObjectIdentifiers {
const val commonName = "2.5.4.3"
const val organizationalUnitName = "2.5.4.11"
const val rsaEncryption = "1.2.840.113549.1.1.1"
const val sha256WithRSAEncryption = "1.2.840.113549.1.1.11"
const val ecPublicKey = "1.2.840.10045.2.1"
const val sha256WithRSAEncryption = "1.2.840.113549.1.1.11"
const val sha256withEcdsa = "1.2.840.10045.4.3.2"
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
package okhttp3.tls.internal.der

import java.math.BigInteger
import java.security.GeneralSecurityException
import java.security.PublicKey
import java.security.Signature
import java.security.SignatureException
import java.security.cert.CertificateFactory
import java.security.cert.X509Certificate
import okio.Buffer

internal data class Certificate(
val tbsCertificate: TbsCertificate,
Expand Down Expand Up @@ -51,6 +58,32 @@ internal data class Certificate(
it.extnID == ObjectIdentifiers.basicConstraints
}
}

/** Returns true if the certificate was signed by [issuer]. */
@Throws(SignatureException::class)
fun checkSignature(issuer: PublicKey): Boolean {
val signedData = CertificateAdapters.tbsCertificate.toDer(tbsCertificate)

val signature = Signature.getInstance(tbsCertificate.signatureAlgorithmName)
signature.initVerify(issuer)
signature.update(signedData.toByteArray())
return signature.verify(signatureValue.byteString.toByteArray())
}

fun toX509Certificate(): X509Certificate {
val data = CertificateAdapters.certificate.toDer(this)
try {
val certificateFactory = CertificateFactory.getInstance("X.509")
val certificates = certificateFactory.generateCertificates(Buffer().write(data).inputStream())
return certificates.single() as X509Certificate
} catch (e: NoSuchElementException) {
throw IllegalArgumentException("failed to decode certificate", e)
} catch (e: IllegalArgumentException) {
throw IllegalArgumentException("failed to decode certificate", e)
} catch (e: GeneralSecurityException) {
throw IllegalArgumentException("failed to decode certificate", e)
}
}
}

internal data class TbsCertificate(
Expand All @@ -74,6 +107,19 @@ internal data class TbsCertificate(
/** Extensions ::= SEQUENCE SIZE (1..MAX) OF Extension */
val extensions: List<Extension>
) {
/**
* Returns the standard name of this certificate's signature algorithm as specified by
* [Signature.getInstance]. Typical values are like "SHA256WithRSA".
*/
val signatureAlgorithmName: String
get() {
return when (signature.algorithm) {
ObjectIdentifiers.sha256WithRSAEncryption -> "SHA256WithRSA"
ObjectIdentifiers.sha256withEcdsa -> "SHA256withECDSA"
else -> error("unexpected signature algorithm: ${signature.algorithm}")
}
}

// Avoid Long.hashCode(long) which isn't available on Android 5.
override fun hashCode(): Int {
var result = 0
Expand Down
8 changes: 4 additions & 4 deletions okhttp-tls/src/test/java/okhttp3/tls/HeldCertificateTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,8 @@ public final class HeldCertificateTest {
.signedBy(root)
.build();

assertThat(root.certificate().getSigAlgName()).isEqualTo("SHA256WITHRSA");
assertThat(leaf.certificate().getSigAlgName()).isEqualTo("SHA256WITHRSA");
assertThat(root.certificate().getSigAlgName()).isEqualToIgnoringCase("SHA256WITHRSA");
assertThat(leaf.certificate().getSigAlgName()).isEqualToIgnoringCase("SHA256WITHRSA");
}

@Test public void rsaSignedByEcdsa() {
Expand All @@ -223,8 +223,8 @@ public final class HeldCertificateTest {
.signedBy(root)
.build();

assertThat(root.certificate().getSigAlgName()).isEqualTo("SHA256WITHECDSA");
assertThat(leaf.certificate().getSigAlgName()).isEqualTo("SHA256WITHECDSA");
assertThat(root.certificate().getSigAlgName()).isEqualToIgnoringCase("SHA256WITHECDSA");
assertThat(leaf.certificate().getSigAlgName()).isEqualToIgnoringCase("SHA256WITHECDSA");
}

@Test public void decodeEcdsa256() throws Exception {
Expand Down
Loading

0 comments on commit 39f46f7

Please sign in to comment.