diff --git a/okhttp-tls/src/main/kotlin/okhttp3/tls/HeldCertificate.kt b/okhttp-tls/src/main/kotlin/okhttp3/tls/HeldCertificate.kt index c5f9194f34ca..9671812b11a6 100644 --- a/okhttp-tls/src/main/kotlin/okhttp3/tls/HeldCertificate.kt +++ b/okhttp-tls/src/main/kotlin/okhttp3/tls/HeldCertificate.kt @@ -44,6 +44,11 @@ import okhttp3.tls.internal.der.CertificateAdapters.generalNameDnsName import okhttp3.tls.internal.der.CertificateAdapters.generalNameIpAddress import okhttp3.tls.internal.der.Extension import okhttp3.tls.internal.der.ObjectIdentifiers +import okhttp3.tls.internal.der.ObjectIdentifiers.basicConstraints +import okhttp3.tls.internal.der.ObjectIdentifiers.organizationalUnitName +import okhttp3.tls.internal.der.ObjectIdentifiers.sha256WithRSAEncryption +import okhttp3.tls.internal.der.ObjectIdentifiers.sha256withEcdsa +import okhttp3.tls.internal.der.ObjectIdentifiers.subjectAlternativeName import okhttp3.tls.internal.der.TbsCertificate import okhttp3.tls.internal.der.Validity import okio.ByteString @@ -187,8 +192,8 @@ class HeldCertificate( class Builder { private var notBefore = -1L private var notAfter = -1L - private var cn: String? = null - private var ou: String? = null + private var commonName: String? = null + private var organizationalUnit: String? = null private val altNames = mutableListOf() private var serialNumber: BigInteger? = null private var keyPair: KeyPair? = null @@ -240,12 +245,12 @@ class HeldCertificate( * [rfc_2818]: https://tools.ietf.org/html/rfc2818 */ fun commonName(cn: String) = apply { - this.cn = cn + this.commonName = cn } /** Sets the certificate's organizational unit (OU). If unset this field will be omitted. */ fun organizationalUnit(ou: String) = apply { - this.ou = ou + this.organizationalUnit = ou } /** Sets this certificate's serial number. If unset the serial number will be 1. */ @@ -378,16 +383,16 @@ class HeldCertificate( private fun subject(): List> { val result = mutableListOf>() - if (ou != null) { + if (organizationalUnit != null) { result += listOf(AttributeTypeAndValue( - type = ObjectIdentifiers.organizationalUnitName, - value = ou + type = organizationalUnitName, + value = organizationalUnit )) } result += listOf(AttributeTypeAndValue( type = ObjectIdentifiers.commonName, - value = cn ?: UUID.randomUUID().toString() + value = commonName ?: UUID.randomUUID().toString() )) return result @@ -407,11 +412,11 @@ class HeldCertificate( if (maxIntermediateCas != -1) { result += Extension( - extnID = ObjectIdentifiers.basicConstraints, + id = basicConstraints, critical = true, - extnValue = BasicConstraints( + value = BasicConstraints( ca = true, - pathLenConstraint = 3 + maxIntermediateCas = maxIntermediateCas.toLong() ) ) } @@ -428,9 +433,9 @@ class HeldCertificate( } } result += Extension( - extnID = ObjectIdentifiers.subjectAlternativeName, + id = subjectAlternativeName, critical = true, - extnValue = extensionValue + value = extensionValue ) } @@ -440,20 +445,21 @@ class HeldCertificate( private fun signatureAlgorithm(signedByKeyPair: KeyPair): AlgorithmIdentifier { return when (signedByKeyPair.private) { is RSAPrivateKey -> AlgorithmIdentifier( - algorithm = ObjectIdentifiers.sha256WithRSAEncryption, + algorithm = sha256WithRSAEncryption, parameters = null ) else -> AlgorithmIdentifier( - algorithm = ObjectIdentifiers.sha256withEcdsa, + algorithm = sha256withEcdsa, parameters = ByteString.EMPTY ) } } private fun generateKeyPair(): KeyPair { - val keyPairGenerator = KeyPairGenerator.getInstance(keyAlgorithm) - keyPairGenerator.initialize(keySize, SecureRandom()) - return keyPairGenerator.generateKeyPair() + return KeyPairGenerator.getInstance(keyAlgorithm).run { + initialize(keySize, SecureRandom()) + generateKeyPair() + } } companion object { diff --git a/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/Adapters.kt b/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/Adapters.kt index 7b894e81583c..2f8423b95732 100644 --- a/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/Adapters.kt +++ b/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/Adapters.kt @@ -239,35 +239,29 @@ internal object Adapters { ): BasicDerAdapter { val codec = object : BasicDerAdapter.Codec { override fun decode(reader: DerReader): T { - reader.pushTypeHint() - try { + return reader.withTypeHint { val list = mutableListOf() while (list.size < members.size) { val member = members[list.size] - list += member.readValue(reader) + list += member.fromDer(reader) } if (reader.hasNext()) { - throw IOException("unexpected ${reader.peekHeader()}") + throw IOException("unexpected ${reader.peekHeader()} at $reader") } - return construct(list) - } finally { - reader.popTypeHint() + return@withTypeHint construct(list) } } override fun encode(writer: DerWriter, value: T) { val list = decompose(value) - writer.pushTypeHint() - try { + writer.withTypeHint { for (i in list.indices) { val adapter = members[i] as DerAdapter - adapter.writeValue(writer, list[i]) + adapter.toDer(writer, list[i]) } - } finally { - writer.popTypeHint() } } } @@ -285,19 +279,19 @@ internal object Adapters { return object : DerAdapter, Any?>> { override fun matches(header: DerHeader): Boolean = true - override fun readValue(reader: DerReader): Pair, Any?> { + override fun fromDer(reader: DerReader): Pair, Any?> { val peekedHeader = reader.peekHeader() - ?: throw IOException("expected a value") + ?: throw IOException("expected a value at $reader") val choice = choices.firstOrNull { it.matches(peekedHeader) } - ?: throw IOException("expected a matching choice but was $peekedHeader") + ?: throw IOException("expected a matching choice but was $peekedHeader at $reader") - return choice to choice.readValue(reader) + return choice to choice.fromDer(reader) } - override fun writeValue(writer: DerWriter, value: Pair, Any?>) { + override fun toDer(writer: DerWriter, value: Pair, Any?>) { val (adapter, v) = value - (adapter as DerAdapter).writeValue(writer, v) + (adapter as DerAdapter).toDer(writer, v) } override fun toString() = choices.joinToString(separator = " OR ") @@ -317,27 +311,23 @@ internal object Adapters { chooser: (Any?) -> DerAdapter<*>? ): DerAdapter { return object : DerAdapter { - override fun matches(header: DerHeader): Boolean = true + override fun matches(header: DerHeader) = true - override fun writeValue(writer: DerWriter, value: Any?) { + override fun toDer(writer: DerWriter, value: Any?) { // If we don't understand this hint, encode the body as a byte string. The byte string // will include a tag and length header as a prefix. - val adapter = chooser(writer.typeHint) - - if (adapter != null) { - (adapter as DerAdapter).writeValue(writer, value) - } else { - writer.writeOctetString(value as ByteString) + val adapter = chooser(writer.typeHint) as DerAdapter? + when { + adapter != null -> adapter.toDer(writer, value) + else -> writer.writeOctetString(value as ByteString) } } - override fun readValue(reader: DerReader): Any? { - val adapter = chooser(reader.typeHint) - - if (adapter != null) { - return (adapter as DerAdapter).readValue(reader) - } else { - return reader.readOctetString() + override fun fromDer(reader: DerReader): Any? { + val adapter = chooser(reader.typeHint) as DerAdapter? + return when { + adapter != null -> adapter.fromDer(reader) + else -> reader.readOctetString() } } } @@ -370,7 +360,7 @@ internal object Adapters { return object : DerAdapter { override fun matches(header: DerHeader): Boolean = true - override fun writeValue(writer: DerWriter, value: Any?) { + override fun toDer(writer: DerWriter, value: Any?) { when { isOptional && value == optionalValue -> { // Write nothing. @@ -385,7 +375,7 @@ internal object Adapters { else -> { for ((type, adapter) in choices) { if (type.isInstance(value) || (value == null && type == Unit::class)) { - (adapter as DerAdapter).writeValue(writer, value) + (adapter as DerAdapter).toDer(writer, value) return } } @@ -393,12 +383,12 @@ internal object Adapters { } } - override fun readValue(reader: DerReader): Any? { + override fun fromDer(reader: DerReader): Any? { if (isOptional && !reader.hasNext()) return optionalValue for ((_, adapter) in choices) { if (adapter.matches(reader.peekHeader()!!)) { - return adapter.readValue(reader) + return adapter.fromDer(reader) } } diff --git a/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/BasicDerAdapter.kt b/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/BasicDerAdapter.kt index 04b77a64eaef..8a526133d1a7 100644 --- a/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/BasicDerAdapter.kt +++ b/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/BasicDerAdapter.kt @@ -52,7 +52,7 @@ internal data class BasicDerAdapter( override fun matches(header: DerHeader) = header.tagClass == tagClass && header.tag == tag - override fun readValue(reader: DerReader): T { + override fun fromDer(reader: DerReader): T { val peekedHeader = reader.peekHeader() if (peekedHeader == null || peekedHeader.tagClass != tagClass || peekedHeader.tag != tag) { if (isOptional) return defaultValue as T @@ -70,7 +70,7 @@ internal data class BasicDerAdapter( return result } - override fun writeValue(writer: DerWriter, value: T) { + override fun toDer(writer: DerWriter, value: T) { if (typeHint) { writer.typeHint = value } diff --git a/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/CertificateAdapters.kt b/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/CertificateAdapters.kt index 032f80c78ea4..7805f919a506 100644 --- a/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/CertificateAdapters.kt +++ b/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/CertificateAdapters.kt @@ -24,6 +24,7 @@ import okio.IOException * * [rfc_5280]: https://tools.ietf.org/html/rfc5280 */ +@Suppress("UNCHECKED_CAST") // This needs to cast decoded collections. internal object CertificateAdapters { /** * ``` @@ -39,33 +40,33 @@ internal object CertificateAdapters { * > 2049 as UTCTime; certificate validity dates in 2050 or later MUST be encoded as * > GeneralizedTime. */ - internal val time = object : DerAdapter { + internal val time: DerAdapter = object : DerAdapter { override fun matches(header: DerHeader): Boolean { return Adapters.UTC_TIME.matches(header) || Adapters.GENERALIZED_TIME.matches(header) } - override fun readValue(reader: DerReader): Long { + override fun fromDer(reader: DerReader): Long { val peekHeader = reader.peekHeader() ?: throw IOException("expected time but was exhausted at $reader") return when { peekHeader.tagClass == Adapters.UTC_TIME.tagClass && peekHeader.tag == Adapters.UTC_TIME.tag -> { - Adapters.UTC_TIME.readValue(reader) + Adapters.UTC_TIME.fromDer(reader) } peekHeader.tagClass == Adapters.GENERALIZED_TIME.tagClass && peekHeader.tag == Adapters.GENERALIZED_TIME.tag -> { - Adapters.GENERALIZED_TIME.readValue(reader) + Adapters.GENERALIZED_TIME.fromDer(reader) } else -> throw IOException("expected time but was $peekHeader at $reader") } } - override fun writeValue(writer: DerWriter, value: Long) { + override fun toDer(writer: DerWriter, value: Long) { if (value < 2_524_608_000_000L) { // 2050-01-01T00:00:00Z - Adapters.UTC_TIME.writeValue(writer, value) + Adapters.UTC_TIME.toDer(writer, value) } else { - Adapters.GENERALIZED_TIME.writeValue(writer, value) + Adapters.GENERALIZED_TIME.toDer(writer, value) } } } @@ -78,12 +79,15 @@ internal object CertificateAdapters { * } * ``` */ - internal val validity = Adapters.sequence( + private val validity: BasicDerAdapter = Adapters.sequence( "Validity", time, time, decompose = { - listOf(it.notBefore, it.notAfter) + listOf( + it.notBefore, + it.notAfter + ) }, construct = { Validity( @@ -93,7 +97,8 @@ internal object CertificateAdapters { } ) - val algorithmParameters = Adapters.usingTypeHint { typeHint -> + /** The type of the parameters depends on the algorithm that precedes it. */ + private val algorithmParameters: DerAdapter = Adapters.usingTypeHint { typeHint -> when (typeHint) { // This type is pretty strange. The spec says that for certain algorithms we must encode null // when it is present, and for others we must omit it! @@ -113,12 +118,22 @@ internal object CertificateAdapters { * } * ``` */ - internal val algorithmIdentifier = Adapters.sequence( + internal val algorithmIdentifier: BasicDerAdapter = Adapters.sequence( "AlgorithmIdentifier", Adapters.OBJECT_IDENTIFIER.asTypeHint(), algorithmParameters, - decompose = { listOf(it.algorithm, it.parameters) }, - construct = { AlgorithmIdentifier(it[0] as String, it[1]) } + decompose = { + listOf( + it.algorithm, + it.parameters + ) + }, + construct = { + AlgorithmIdentifier( + algorithm = it[0] as String, + parameters = it[1] + ) + } ) /** @@ -129,12 +144,22 @@ internal object CertificateAdapters { * } * ``` */ - internal val basicConstraints = Adapters.sequence( + private val basicConstraints: BasicDerAdapter = Adapters.sequence( "BasicConstraints", Adapters.BOOLEAN.optional(defaultValue = false), Adapters.INTEGER_AS_LONG.optional(), - decompose = { listOf(it.ca, it.pathLenConstraint) }, - construct = { BasicConstraints(it[0] as Boolean, it[1] as Long?) } + decompose = { + listOf( + it.ca, + it.maxIntermediateCas + ) + }, + construct = { + BasicConstraints( + ca = it[0] as Boolean, + maxIntermediateCas = it[1] as Long? + ) + } ) /** @@ -153,10 +178,12 @@ internal object CertificateAdapters { * registeredID [8] OBJECT IDENTIFIER * } * ``` + * + * The first property of the pair is the adapter that was used, the second property is the value. */ internal val generalNameDnsName = Adapters.IA5_STRING.withTag(tag = 2L) internal val generalNameIpAddress = Adapters.OCTET_STRING.withTag(tag = 7L) - internal val generalName = Adapters.choice( + internal val generalName: DerAdapter, Any?>> = Adapters.choice( generalNameDnsName, generalNameIpAddress ) @@ -168,13 +195,14 @@ internal object CertificateAdapters { * GeneralNames ::= SEQUENCE SIZE (1..MAX) OF GeneralName * ``` */ - internal val subjectAlternativeName = generalName.asSequenceOf() + private val subjectAlternativeName: BasicDerAdapter, Any?>>> = + generalName.asSequenceOf() /** * This uses the preceding extension ID to select which adapter to use for the extension value * that follows. */ - internal val extensionValue = Adapters.usingTypeHint { typeHint -> + private val extensionValue: BasicDerAdapter = Adapters.usingTypeHint { typeHint -> when (typeHint) { ObjectIdentifiers.subjectAlternativeName -> subjectAlternativeName ObjectIdentifiers.basicConstraints -> basicConstraints @@ -198,13 +226,25 @@ internal object CertificateAdapters { * } * ``` */ - internal val extension = Adapters.sequence( + internal val extension: BasicDerAdapter = Adapters.sequence( "Extension", Adapters.OBJECT_IDENTIFIER.asTypeHint(), Adapters.BOOLEAN.optional(defaultValue = false), extensionValue, - decompose = { listOf(it.extnID, it.critical, it.extnValue) }, - construct = { Extension(it[0] as String, it[1] as Boolean, it[2]) } + decompose = { + listOf( + it.id, + it.critical, + it.value + ) + }, + construct = { + Extension( + id = it[0] as String, + critical = it[1] as Boolean, + value = it[2] + ) + } ) /** @@ -219,12 +259,22 @@ internal object CertificateAdapters { * AttributeValue ::= ANY -- DEFINED BY AttributeType * ``` */ - internal val attributeTypeAndValue = Adapters.sequence( + private val attributeTypeAndValue: BasicDerAdapter = Adapters.sequence( "AttributeTypeAndValue", Adapters.OBJECT_IDENTIFIER, Adapters.any(), - decompose = { listOf(it.type, it.value) }, - construct = { AttributeTypeAndValue(it[0] as String, it[1]) } + decompose = { + listOf( + it.type, + it.value + ) + }, + construct = { + AttributeTypeAndValue( + type = it[0] as String, + value = it[1] + ) + } ) /** @@ -234,7 +284,8 @@ internal object CertificateAdapters { * RelativeDistinguishedName ::= SET SIZE (1..MAX) OF AttributeTypeAndValue * ``` */ - internal val rdnSequence = attributeTypeAndValue.asSetOf().asSequenceOf() + internal val rdnSequence: BasicDerAdapter>> = + attributeTypeAndValue.asSetOf().asSequenceOf() /** * ``` @@ -244,7 +295,7 @@ internal object CertificateAdapters { * } * ``` */ - internal val name = Adapters.choice( + internal val name: DerAdapter, Any?>> = Adapters.choice( rdnSequence ) @@ -256,12 +307,22 @@ internal object CertificateAdapters { * } * ``` */ - internal val subjectPublicKeyInfo = Adapters.sequence( + internal val subjectPublicKeyInfo: BasicDerAdapter = Adapters.sequence( "SubjectPublicKeyInfo", algorithmIdentifier, Adapters.BIT_STRING, - decompose = { listOf(it.algorithm, it.subjectPublicKey) }, - construct = { SubjectPublicKeyInfo(it[0] as AlgorithmIdentifier, it[1] as BitString) } + decompose = { + listOf( + it.algorithm, + it.subjectPublicKey + ) + }, + construct = { + SubjectPublicKeyInfo( + algorithm = it[0] as AlgorithmIdentifier, + subjectPublicKey = it[1] as BitString + ) + } ) /** @@ -280,7 +341,7 @@ internal object CertificateAdapters { * } * ``` */ - internal val tbsCertificate = Adapters.sequence( + internal val tbsCertificate: BasicDerAdapter = Adapters.sequence( "TBSCertificate", Adapters.INTEGER_AS_LONG.withExplicitBox(tag = 0L).optional(defaultValue = 0), // v1 == 0 Adapters.INTEGER_AS_BIG_INTEGER, @@ -331,7 +392,7 @@ internal object CertificateAdapters { * } * ``` */ - internal val certificate = Adapters.sequence( + internal val certificate: BasicDerAdapter = Adapters.sequence( "Certificate", tbsCertificate, algorithmIdentifier, @@ -352,7 +413,28 @@ internal object CertificateAdapters { } ) - internal val privateKeyInfo = Adapters.sequence( + /** + * ``` + * Version ::= INTEGER { v1(0), v2(1) } (v1, ..., v2) + * + * PrivateKeyAlgorithmIdentifier ::= AlgorithmIdentifier + * + * PrivateKey ::= OCTET STRING + * + * OneAsymmetricKey ::= SEQUENCE { + * version Version, + * privateKeyAlgorithm PrivateKeyAlgorithmIdentifier, + * privateKey PrivateKey, + * attributes [0] Attributes OPTIONAL, + * ..., + * [[2: publicKey [1] PublicKey OPTIONAL ]], + * ... + * } + * + * PrivateKeyInfo ::= OneAsymmetricKey + * ``` + */ + internal val privateKeyInfo: BasicDerAdapter = Adapters.sequence( "PrivateKeyInfo", Adapters.INTEGER_AS_LONG, algorithmIdentifier, diff --git a/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/DerAdapter.kt b/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/DerAdapter.kt index f0c32d884e25..eccb85019bf4 100644 --- a/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/DerAdapter.kt +++ b/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/DerAdapter.kt @@ -18,6 +18,9 @@ package okhttp3.tls.internal.der import okio.Buffer import okio.ByteString +/** + * Encode and decode a model object like a [Long] or [Certificate] as DER bytes. + */ internal interface DerAdapter { /** Returns true if this adapter can read [header] in a choice. */ fun matches(header: DerHeader): Boolean @@ -34,28 +37,28 @@ internal interface DerAdapter { * * If there's nothing to read and no default value, this will throw an exception. */ - fun readValue(reader: DerReader): T + fun fromDer(reader: DerReader): T + + fun fromDer(byteString: ByteString): T { + val buffer = Buffer().write(byteString) + val reader = DerReader(buffer) + return fromDer(reader) + } /** * Writes [value] to this adapter, unless it is the default value and can be safely omitted. * * If this does write a value, it will write a tag and a length and a full value. */ - fun writeValue(writer: DerWriter, value: T) + fun toDer(writer: DerWriter, value: T) fun toDer(value: T): ByteString { val buffer = Buffer() val writer = DerWriter(buffer) - writeValue(writer, value) + toDer(writer, value) return buffer.readByteString() } - fun fromDer(byteString: ByteString): T { - val buffer = Buffer().write(byteString) - val reader = DerReader(buffer) - return readValue(reader) - } - /** * Returns an adapter that expects this value wrapped by another value. Typically this occurs * when a value has both a context or application tag and a universal tag. @@ -65,6 +68,10 @@ internal interface DerAdapter { * ``` * [5] EXPLICIT UTF8String * ``` + * + * @param forceConstructed non-null to set the constructed bit to the specified value, even if the + * writing process sets something else. This is used to encode SEQUENCES in values that are + * declared to have non-constructed values, like OCTET STRING values. */ @Suppress("UNCHECKED_CAST") // read() produces a single element of the expected type. fun withExplicitBox( @@ -73,9 +80,9 @@ internal interface DerAdapter { forceConstructed: Boolean? = null ): BasicDerAdapter { val codec = object : BasicDerAdapter.Codec { - override fun decode(reader: DerReader) = readValue(reader) + override fun decode(reader: DerReader) = fromDer(reader) override fun encode(writer: DerWriter, value: T) { - writeValue(writer, value) + toDer(writer, value) if (forceConstructed != null) { writer.constructed = forceConstructed } @@ -99,14 +106,14 @@ internal interface DerAdapter { val codec = object : BasicDerAdapter.Codec> { override fun encode(writer: DerWriter, value: List) { for (v in value) { - writeValue(writer, v) + toDer(writer, v) } } override fun decode(reader: DerReader): List { val result = mutableListOf() while (reader.hasNext()) { - result += readValue(reader) + result += fromDer(reader) } return result } diff --git a/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/DerReader.kt b/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/DerReader.kt index adc54fba7db7..5d7ffc0247b0 100644 --- a/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/DerReader.kt +++ b/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/DerReader.kt @@ -47,9 +47,9 @@ internal class DerReader(source: Source) { get() = countingSource.bytesRead - source.buffer.size /** How many bytes to read before [peekHeader] should return false, or -1L for no limit. */ - private var limit: Long = -1L + private var limit = -1L - /** Type hints scoped to the call stack, manipulated with [pushTypeHint] and [popTypeHint]. */ + /** Type hints scoped to the call stack, manipulated with [withTypeHint]. */ private val typeHintStack = mutableListOf() /** @@ -59,18 +59,18 @@ internal class DerReader(source: Source) { var typeHint: Any? get() = typeHintStack.lastOrNull() set(value) { - typeHintStack.set(typeHintStack.size - 1, value) + typeHintStack[typeHintStack.size - 1] = value } /** Names leading to the current location in the ASN.1 document. */ private val path = mutableListOf() - private var constructed: Boolean = false + private var constructed = false private var peekedHeader: DerHeader? = null private val bytesLeft: Long - get() = if (limit == -1L) -1L else limit - byteCount + get() = if (limit == -1L) -1L else (limit - byteCount) fun hasNext(): Boolean = peekHeader() != null @@ -110,35 +110,36 @@ internal class DerReader(source: Source) { if (limit == -1L && source.exhausted()) return END_OF_DATA // Read the tag. - val tag: Long val tagAndClass = source.readByte().toInt() and 0xff val tagClass = tagAndClass and 0b1100_0000 val constructed = (tagAndClass and 0b0010_0000) == 0b0010_0000 val tag0 = tagAndClass and 0b0001_1111 - if (tag0 == 0b0001_1111) { - tag = readVariableLengthLong() - } else { - tag = tag0.toLong() + val tag = when (tag0) { + 0b0001_1111 -> readVariableLengthLong() + else -> tag0.toLong() } // Read the length. - val length: Long val length0 = source.readByte().toInt() and 0xff - if (length0 == 0b1000_0000) { - // Indefinite length. - length = -1L - } else if ((length0 and 0b1000_0000) == 0b1000_0000) { - // Length specified over multiple bytes. - val lengthBytes = length0 and 0b0111_1111 - var lengthBits = source.readByte().toLong() and 0xff - for (i in 1 until lengthBytes) { - lengthBits = lengthBits shl 8 - lengthBits += source.readByte().toInt() and 0xff + val length = when { + length0 == 0b1000_0000 -> { + // Indefinite length. + -1L + } + (length0 and 0b1000_0000) == 0b1000_0000 -> { + // Length specified over multiple bytes. + val lengthBytes = length0 and 0b0111_1111 + var lengthBits = source.readByte().toLong() and 0xff + for (i in 1 until lengthBytes) { + lengthBits = lengthBits shl 8 + lengthBits += source.readByte().toInt() and 0xff + } + lengthBits + } + else -> { + // Length is 127 or fewer bytes. + (length0 and 0b0111_1111).toLong() } - length = lengthBits - } else { - // Length is 127 or fewer bytes. - length = (length0 and 0b0111_1111).toLong() } // Note that this may be be an encoded "end of data" header. @@ -177,29 +178,34 @@ internal class DerReader(source: Source) { } } - fun pushTypeHint() { + /** + * Execute [block] with a new namespace for type hints. Type hints from the enclosing type are no + * longer usable by the current type's members. + */ + fun withTypeHint(block: () -> T): T { typeHintStack.add(null) - } - - fun popTypeHint() { - typeHintStack.removeAt(typeHintStack.size - 1) + try { + return block() + } finally { + typeHintStack.removeAt(typeHintStack.size - 1) + } } fun readBoolean(): Boolean { - if (bytesLeft != 1L) throw ProtocolException("unexpected length: $bytesLeft") + if (bytesLeft != 1L) throw ProtocolException("unexpected length: $bytesLeft at $this") return source.readByte().toInt() != 0 } fun readBigInteger(): BigInteger { - if (bytesLeft == 0L) throw ProtocolException("unexpected length: $bytesLeft") + if (bytesLeft == 0L) throw ProtocolException("unexpected length: $bytesLeft at $this") val byteArray = source.readByteArray(bytesLeft) return BigInteger(byteArray) } fun readLong(): Long { - if (bytesLeft !in 1..8) throw ProtocolException("unexpected length: $bytesLeft") + if (bytesLeft !in 1..8) throw ProtocolException("unexpected length: $bytesLeft at $this") - var result = source.readByte().toLong() // No "and 0xff" because this is a signed value. + var result = source.readByte().toLong() // No "and 0xff" because this is a signed value! while (byteCount < limit) { result = result shl 8 result += source.readByte().toInt() and 0xff @@ -314,29 +320,29 @@ internal class DerReader(source: Source) { override fun toString() = path.joinToString(separator = " / ") - /** A source that keeps track of how many bytes it's consumed. */ - private class CountingSource(source: Source) : ForwardingSource(source) { - var bytesRead = 0L - - override fun read(sink: Buffer, byteCount: Long): Long { - val result = delegate.read(sink, byteCount) - if (result == -1L) return -1L - bytesRead += result - return result - } - } - companion object { /** * A synthetic value that indicates there's no more bytes. Values with equivalent data may also * show up in ASN.1 streams to also indicate the end of SEQUENCE, SET or other constructed * value. */ - val END_OF_DATA = DerHeader( + private val END_OF_DATA = DerHeader( tagClass = DerHeader.TAG_CLASS_UNIVERSAL, tag = DerHeader.TAG_END_OF_CONTENTS, constructed = false, length = -1L ) } + + /** A source that keeps track of how many bytes it's consumed. */ + private class CountingSource(source: Source) : ForwardingSource(source) { + var bytesRead = 0L + + override fun read(sink: Buffer, byteCount: Long): Long { + val result = delegate.read(sink, byteCount) + if (result == -1L) return -1L + bytesRead += result + return result + } + } } diff --git a/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/DerWriter.kt b/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/DerWriter.kt index c76776d0b281..7a169671274d 100644 --- a/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/DerWriter.kt +++ b/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/DerWriter.kt @@ -93,19 +93,16 @@ internal class DerWriter(sink: BufferedSink) { } /** - * Create a new namespace for type hints. Type hints from the enclosing type are no longer usable - * by the current type's members. + * Execute [block] with a new namespace for type hints. Type hints from the enclosing type are no + * longer usable by the current type's members. */ - fun pushTypeHint() { + fun withTypeHint(block: () -> T): T { typeHintStack.add(null) - } - - /** - * Remove the current namespace when it is going out of scope. Calls to [pushTypeHint] and - * [popTypeHint] should be balanced. - */ - fun popTypeHint() { - typeHintStack.removeAt(typeHintStack.size - 1) + try { + return block() + } finally { + typeHintStack.removeAt(typeHintStack.size - 1) + } } private fun sink(): BufferedSink = stack[stack.size - 1] diff --git a/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/ObjectIdentifiers.kt b/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/ObjectIdentifiers.kt index 492d20809656..4bafafb7f7d7 100644 --- a/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/ObjectIdentifiers.kt +++ b/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/ObjectIdentifiers.kt @@ -17,12 +17,12 @@ package okhttp3.tls.internal.der /** ASN.1 object identifiers used internally by this implementation. */ internal object ObjectIdentifiers { + const val ecPublicKey = "1.2.840.10045.2.1" + const val sha256withEcdsa = "1.2.840.10045.4.3.2" + const val rsaEncryption = "1.2.840.113549.1.1.1" + const val sha256WithRSAEncryption = "1.2.840.113549.1.1.11" const val subjectAlternativeName = "2.5.29.17" const val basicConstraints = "2.5.29.19" const val commonName = "2.5.4.3" const val organizationalUnitName = "2.5.4.11" - const val rsaEncryption = "1.2.840.113549.1.1.1" - const val ecPublicKey = "1.2.840.10045.2.1" - const val sha256WithRSAEncryption = "1.2.840.113549.1.1.11" - const val sha256withEcdsa = "1.2.840.10045.4.3.2" } diff --git a/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/certificates.kt b/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/certificates.kt index 41b6897b7187..19b6b7c667df 100644 --- a/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/certificates.kt +++ b/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/certificates.kt @@ -49,14 +49,14 @@ internal data class Certificate( val subjectAlternativeNames: Extension get() { return tbsCertificate.extensions.first { - it.extnID == ObjectIdentifiers.subjectAlternativeName + it.id == ObjectIdentifiers.subjectAlternativeName } } val basicConstraints: Extension get() { return tbsCertificate.extensions.first { - it.extnID == ObjectIdentifiers.basicConstraints + it.id == ObjectIdentifiers.basicConstraints } } @@ -65,10 +65,11 @@ internal data class Certificate( fun checkSignature(issuer: PublicKey): Boolean { val signedData = CertificateAdapters.tbsCertificate.toDer(tbsCertificate) - val signature = Signature.getInstance(tbsCertificate.signatureAlgorithmName) - signature.initVerify(issuer) - signature.update(signedData.toByteArray()) - return signature.verify(signatureValue.byteString.toByteArray()) + return Signature.getInstance(tbsCertificate.signatureAlgorithmName).run { + initVerify(issuer) + update(signedData.toByteArray()) + verify(signatureValue.byteString.toByteArray()) + } } fun toX509Certificate(): X509Certificate { @@ -88,24 +89,16 @@ internal data class Certificate( } internal data class TbsCertificate( - /** Version ::= INTEGER { v1(0), v2(1), v3(2) } */ + /** This is a integer enum. Use 0L for v1, 1L for v2, and 2L for v3. */ val version: Long, - - /** CertificateSerialNumber ::= INTEGER */ val serialNumber: BigInteger, val signature: AlgorithmIdentifier, val issuer: List>, val validity: Validity, val subject: List>, val subjectPublicKeyInfo: SubjectPublicKeyInfo, - - /** UniqueIdentifier ::= BIT STRING */ val issuerUniqueID: BitString?, - - /** UniqueIdentifier ::= BIT STRING */ val subjectUniqueID: BitString?, - - /** Extensions ::= SEQUENCE SIZE (1..MAX) OF Extension */ val extensions: List ) { /** @@ -139,11 +132,14 @@ internal data class TbsCertificate( } internal data class AlgorithmIdentifier( + /** An OID string like "1.2.840.113549.1.1.11" for sha256WithRSAEncryption. */ val algorithm: String, + /** Parameters of a type implied by [algorithm]. */ val parameters: Any? ) internal data class AttributeTypeAndValue( + /** An OID string like "2.5.4.11" for organizationalUnitName. */ val type: String, val value: Any? ) @@ -167,39 +163,19 @@ internal data class SubjectPublicKeyInfo( ) internal data class Extension( - val extnID: String, + val id: String, val critical: Boolean, - val extnValue: Any? + val value: Any? ) internal data class BasicConstraints( + /** True if this certificate can be used as a Certificate Authority (CA). */ val ca: Boolean, - val pathLenConstraint: Long? + /** The maximum number of intermediate CAs between this and leaf certificates. */ + val maxIntermediateCas: Long? ) -/** - * A private key. Note that this class doesn't support attributes or an embedded public key. - * - * ``` - * Version ::= INTEGER { v1(0), v2(1) } (v1, ..., v2) - * - * PrivateKeyAlgorithmIdentifier ::= AlgorithmIdentifier - * - * PrivateKey ::= OCTET STRING - * - * OneAsymmetricKey ::= SEQUENCE { - * version Version, - * privateKeyAlgorithm PrivateKeyAlgorithmIdentifier, - * privateKey PrivateKey, - * attributes [0] Attributes OPTIONAL, - * ..., - * [[2: publicKey [1] PublicKey OPTIONAL ]], - * ... - * } - * - * PrivateKeyInfo ::= OneAsymmetricKey - * ``` - */ +/** A private key. Note that this class doesn't support attributes or an embedded public key. */ internal data class PrivateKeyInfo( val version: Long, // v1(0), v2(1) val algorithmIdentifier: AlgorithmIdentifier, // v1(0), v2(1) diff --git a/okhttp-tls/src/test/java/okhttp3/tls/internal/der/DerCertificatesTest.kt b/okhttp-tls/src/test/java/okhttp3/tls/internal/der/DerCertificatesTest.kt index e66d2af45a4e..69192dffafe9 100644 --- a/okhttp-tls/src/test/java/okhttp3/tls/internal/der/DerCertificatesTest.kt +++ b/okhttp-tls/src/test/java/okhttp3/tls/internal/der/DerCertificatesTest.kt @@ -24,8 +24,12 @@ import java.util.Date import java.util.TimeZone import okhttp3.tls.HeldCertificate import okhttp3.tls.decodeCertificatePem +import okhttp3.tls.internal.der.ObjectIdentifiers.basicConstraints +import okhttp3.tls.internal.der.ObjectIdentifiers.commonName +import okhttp3.tls.internal.der.ObjectIdentifiers.organizationalUnitName import okhttp3.tls.internal.der.ObjectIdentifiers.rsaEncryption import okhttp3.tls.internal.der.ObjectIdentifiers.sha256WithRSAEncryption +import okhttp3.tls.internal.der.ObjectIdentifiers.subjectAlternativeName import okio.Buffer import okio.ByteString import okio.ByteString.Companion.decodeBase64 @@ -39,18 +43,14 @@ internal class DerCertificatesTest { private val country = "1.3.6.1.4.1.311.60.2.1.3" private val certificateTransparencySignedCertificateTimestamps = "1.3.6.1.4.1.11129.2.4.2" private val authorityInfoAccess = "1.3.6.1.5.5.7.1.1" - private val commonName = "2.5.4.3" private val serialNumber = "2.5.4.5" private val countryName = "2.5.4.6" private val localityName = "2.5.4.7" private val stateOrProvinceName = "2.5.4.8" private val organizationName = "2.5.4.10" - private val organizationalUnitName = "2.5.4.11" private val businessCategory = "2.5.4.15" private val subjectKeyIdentifier = "2.5.29.14" private val keyUsage = "2.5.29.15" - private val subjectAltName = "2.5.29.17" - private val basicConstraints = "2.5.29.19" private val crlDistributionPoints = "2.5.29.31" private val certificatePolicies = "2.5.29.32" private val authorityKeyIdentifier = "2.5.29.35" @@ -298,45 +298,45 @@ internal class DerCertificatesTest { subjectUniqueID = null, extensions = listOf( Extension( - extnID = keyUsage, + id = keyUsage, critical = true, - extnValue = "03020106".decodeHex() + value = "03020106".decodeHex() ), Extension( - extnID = basicConstraints, + id = basicConstraints, critical = true, - extnValue = BasicConstraints( + value = BasicConstraints( ca = true, - pathLenConstraint = 1L + maxIntermediateCas = 1L ) ), Extension( - extnID = authorityInfoAccess, + id = authorityInfoAccess, critical = false, - extnValue = ("3025302306082b060105050730018617687474703a2f2f6f6373702e656" + + value = ("3025302306082b060105050730018617687474703a2f2f6f6373702e656" + "e74727573742e6e6574").decodeHex() ), Extension( - extnID = crlDistributionPoints, + id = crlDistributionPoints, critical = false, - extnValue = ("302a3028a026a0248622687474703a2f2f63726c2e656e74727573742e6" + + value = ("302a3028a026a0248622687474703a2f2f63726c2e656e74727573742e6" + "e65742f726f6f746361312e63726c").decodeHex() ), Extension( - extnID = certificatePolicies, + id = certificatePolicies, critical = false, - extnValue = ("303230300604551d20003028302606082b06010505070201161a6874747" + + value = ("303230300604551d20003028302606082b06010505070201161a6874747" + "03a2f2f7777772e656e74727573742e6e65742f435053").decodeHex() ), Extension( - extnID = subjectKeyIdentifier, + id = subjectKeyIdentifier, critical = false, - extnValue = "04146a72267ad01eef7de73b6951d46c8d9f901266ab".decodeHex() + value = "04146a72267ad01eef7de73b6951d46c8d9f901266ab".decodeHex() ), Extension( - extnID = authorityKeyIdentifier, + id = authorityKeyIdentifier, critical = false, - extnValue = "301680146890e467a4a65380c78666a4f1f74b43fb84bd6d".decodeHex() + value = "301680146890e467a4a65380c78666a4f1f74b43fb84bd6d".decodeHex() ) ) ), @@ -539,17 +539,17 @@ internal class DerCertificatesTest { subjectUniqueID = null, extensions = listOf( Extension( - extnID = subjectAltName, + id = subjectAlternativeName, critical = false, - extnValue = listOf( + value = listOf( CertificateAdapters.generalNameDnsName to "cash.app", CertificateAdapters.generalNameDnsName to "www.cash.app" ) ), Extension( - extnID = certificateTransparencySignedCertificateTimestamps, + id = certificateTransparencySignedCertificateTimestamps, critical = false, - extnValue = ("0482016b01690077005614069a2fd7c2ecd3f5e1bd44b23ec74676b9bc9" + + value = ("0482016b01690077005614069a2fd7c2ecd3f5e1bd44b23ec74676b9bc9" + "9115cc0ef949855d689d0dd0000017173d3269b0000040300483046022100a9e58ad" + "ee5adf4b5f5a7797480f80dc58041d78da8aad44c9cc0416a74cacb62022100eb463" + "ecf46c5725dfd50471804e4c665e8ae9790129b69706502a3e96fccf685007700877" + @@ -563,51 +563,51 @@ internal class DerCertificatesTest { .decodeHex() ), Extension( - extnID = keyUsage, + id = keyUsage, critical = true, - extnValue = "030205a0".decodeHex() + value = "030205a0".decodeHex() ), Extension( - extnID = extendedKeyUsage, + id = extendedKeyUsage, critical = false, - extnValue = "301406082b0601050507030106082b06010505070302".decodeHex() + value = "301406082b0601050507030106082b06010505070302".decodeHex() ), Extension( - extnID = authorityInfoAccess, + id = authorityInfoAccess, critical = false, - extnValue = ("305a302306082b060105050730018617687474703a2f2f6f6373702e656" + + value = ("305a302306082b060105050730018617687474703a2f2f6f6373702e656" + "e74727573742e6e6574303306082b060105050730028627687474703a2f2f6169612" + "e656e74727573742e6e65742f6c316d2d636861696e3235362e636572").decodeHex() ), Extension( - extnID = crlDistributionPoints, + id = crlDistributionPoints, critical = false, - extnValue = ("302a3028a026a0248622687474703a2f2f63726c2e656e74727573742e6" + + value = ("302a3028a026a0248622687474703a2f2f63726c2e656e74727573742e6" + "e65742f6c6576656c316d2e63726c").decodeHex() ), Extension( - extnID = certificatePolicies, + id = certificatePolicies, critical = false, - extnValue = ("30413036060a6086480186fa6c0a01023028302606082b0601050507020" + + value = ("30413036060a6086480186fa6c0a01023028302606082b0601050507020" + "1161a687474703a2f2f7777772e656e74727573742e6e65742f72706130070605678" + "10c0101").decodeHex() ), Extension( - extnID = authorityKeyIdentifier, + id = authorityKeyIdentifier, critical = false, - extnValue = ("30168014c3f7d0b52a30adaf0d9121703954ddbc8970c73a").decodeHex() + value = ("30168014c3f7d0b52a30adaf0d9121703954ddbc8970c73a").decodeHex() ), Extension( - extnID = subjectKeyIdentifier, + id = subjectKeyIdentifier, critical = false, - extnValue = "041475fd24c2df592599e32f3373e18c0450dd1b87b6".decodeHex() + value = "041475fd24c2df592599e32f3373e18c0450dd1b87b6".decodeHex() ), Extension( - extnID = basicConstraints, + id = basicConstraints, critical = false, - extnValue = BasicConstraints( + value = BasicConstraints( ca = false, - pathLenConstraint = null + maxIntermediateCas = null ) ) ) @@ -642,16 +642,16 @@ internal class DerCertificatesTest { .fromDer(certificateByteString) assertThat(okHttpCertificate.basicConstraints).isEqualTo(Extension( - extnID = ObjectIdentifiers.basicConstraints, + id = basicConstraints, critical = true, - extnValue = BasicConstraints(true, 3) + value = BasicConstraints(true, 3) )) assertThat(okHttpCertificate.commonName).isEqualTo("Jurassic Park") assertThat(okHttpCertificate.organizationalUnitName).isEqualTo("Gene Research") assertThat(okHttpCertificate.subjectAlternativeNames).isEqualTo(Extension( - extnID = ObjectIdentifiers.subjectAlternativeName, + id = subjectAlternativeName, critical = true, - extnValue = listOf( + value = listOf( CertificateAdapters.generalNameDnsName to "*.example.com", CertificateAdapters.generalNameDnsName to "www.example.org" ) @@ -677,7 +677,7 @@ internal class DerCertificatesTest { "mIE65swMM5/RNhS4aFjez/MwxFNOHaxc9VgCwYPXCLOtdf7AVovdyG0XWgbUXH+NyxKwboE").decodeBase64()!! val x509PublicKey = encodeKey( - algorithm = "1.2.840.113549.1.1.1", + algorithm = rsaEncryption, publicKeyBytes = publicKeyBytes ) val keyFactory = KeyFactory.getInstance("RSA") diff --git a/okhttp-tls/src/test/java/okhttp3/tls/internal/der/DerTest.kt b/okhttp-tls/src/test/java/okhttp3/tls/internal/der/DerTest.kt index 0052566c0b26..c5cb53f62e2b 100644 --- a/okhttp-tls/src/test/java/okhttp3/tls/internal/der/DerTest.kt +++ b/okhttp-tls/src/test/java/okhttp3/tls/internal/der/DerTest.kt @@ -22,6 +22,10 @@ import java.util.Date import java.util.TimeZone import okhttp3.tls.internal.der.CertificateAdapters.generalNameDnsName import okhttp3.tls.internal.der.CertificateAdapters.generalNameIpAddress +import okhttp3.tls.internal.der.ObjectIdentifiers.basicConstraints +import okhttp3.tls.internal.der.ObjectIdentifiers.commonName +import okhttp3.tls.internal.der.ObjectIdentifiers.sha256WithRSAEncryption +import okhttp3.tls.internal.der.ObjectIdentifiers.subjectAlternativeName import okio.Buffer import okio.ByteString.Companion.decodeHex import okio.ByteString.Companion.encodeUtf8 @@ -695,11 +699,11 @@ internal class DerTest { @Test fun `decode object identifier`() { val objectIdentifier = Adapters.OBJECT_IDENTIFIER.fromDer("06092a864886f70d01010b".decodeHex()) - assertThat(objectIdentifier).isEqualTo("1.2.840.113549.1.1.11") + assertThat(objectIdentifier).isEqualTo(sha256WithRSAEncryption) } @Test fun `encode object identifier`() { - val byteString = Adapters.OBJECT_IDENTIFIER.toDer("1.2.840.113549.1.1.11") + val byteString = Adapters.OBJECT_IDENTIFIER.toDer(sha256WithRSAEncryption) assertThat(byteString).isEqualTo("06092a864886f70d01010b".decodeHex()) } @@ -718,7 +722,7 @@ internal class DerTest { .fromDer("300d06092a864886f70d01010b0500".decodeHex()) assertThat(algorithmIdentifier).isEqualTo( AlgorithmIdentifier( - algorithm = "1.2.840.113549.1.1.11", + algorithm = sha256WithRSAEncryption, parameters = null ) ) @@ -727,7 +731,7 @@ internal class DerTest { @Test fun `encode sequence algorithm`() { val byteString = CertificateAdapters.algorithmIdentifier.toDer( AlgorithmIdentifier( - algorithm = "1.2.840.113549.1.1.11", + algorithm = sha256WithRSAEncryption, parameters = null ) ) @@ -781,7 +785,7 @@ internal class DerTest { @Test fun `extension with type hint for basic constraints`() { val extension = Extension( - ObjectIdentifiers.basicConstraints, + basicConstraints, false, BasicConstraints(true, 4) ) @@ -795,7 +799,7 @@ internal class DerTest { @Test fun `extension with type hint for subject alternative names`() { val extension = Extension( - ObjectIdentifiers.subjectAlternativeName, + subjectAlternativeName, false, listOf( generalNameDnsName to "cash.app", @@ -812,7 +816,7 @@ internal class DerTest { @Test fun `extension with unknown type hint`() { val extension = Extension( - "2.5.4.3", // common name is not an extension. + commonName, // common name is not an extension. false, "3006800109810109".decodeHex() )