Skip to content

Commit

Permalink
Improve code style for the new DER package (#6157)
Browse files Browse the repository at this point in the history
Improve docs, fix some names, fix some internal APIs.
  • Loading branch information
swankjesse authored Jun 30, 2020
1 parent 45df82e commit 353a52b
Show file tree
Hide file tree
Showing 11 changed files with 326 additions and 258 deletions.
42 changes: 24 additions & 18 deletions okhttp-tls/src/main/kotlin/okhttp3/tls/HeldCertificate.kt
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ import okhttp3.tls.internal.der.CertificateAdapters.generalNameDnsName
import okhttp3.tls.internal.der.CertificateAdapters.generalNameIpAddress
import okhttp3.tls.internal.der.Extension
import okhttp3.tls.internal.der.ObjectIdentifiers
import okhttp3.tls.internal.der.ObjectIdentifiers.basicConstraints
import okhttp3.tls.internal.der.ObjectIdentifiers.organizationalUnitName
import okhttp3.tls.internal.der.ObjectIdentifiers.sha256WithRSAEncryption
import okhttp3.tls.internal.der.ObjectIdentifiers.sha256withEcdsa
import okhttp3.tls.internal.der.ObjectIdentifiers.subjectAlternativeName
import okhttp3.tls.internal.der.TbsCertificate
import okhttp3.tls.internal.der.Validity
import okio.ByteString
Expand Down Expand Up @@ -187,8 +192,8 @@ class HeldCertificate(
class Builder {
private var notBefore = -1L
private var notAfter = -1L
private var cn: String? = null
private var ou: String? = null
private var commonName: String? = null
private var organizationalUnit: String? = null
private val altNames = mutableListOf<String>()
private var serialNumber: BigInteger? = null
private var keyPair: KeyPair? = null
Expand Down Expand Up @@ -240,12 +245,12 @@ class HeldCertificate(
* [rfc_2818]: https://tools.ietf.org/html/rfc2818
*/
fun commonName(cn: String) = apply {
this.cn = cn
this.commonName = cn
}

/** Sets the certificate's organizational unit (OU). If unset this field will be omitted. */
fun organizationalUnit(ou: String) = apply {
this.ou = ou
this.organizationalUnit = ou
}

/** Sets this certificate's serial number. If unset the serial number will be 1. */
Expand Down Expand Up @@ -378,16 +383,16 @@ class HeldCertificate(
private fun subject(): List<List<AttributeTypeAndValue>> {
val result = mutableListOf<List<AttributeTypeAndValue>>()

if (ou != null) {
if (organizationalUnit != null) {
result += listOf(AttributeTypeAndValue(
type = ObjectIdentifiers.organizationalUnitName,
value = ou
type = organizationalUnitName,
value = organizationalUnit
))
}

result += listOf(AttributeTypeAndValue(
type = ObjectIdentifiers.commonName,
value = cn ?: UUID.randomUUID().toString()
value = commonName ?: UUID.randomUUID().toString()
))

return result
Expand All @@ -407,11 +412,11 @@ class HeldCertificate(

if (maxIntermediateCas != -1) {
result += Extension(
extnID = ObjectIdentifiers.basicConstraints,
id = basicConstraints,
critical = true,
extnValue = BasicConstraints(
value = BasicConstraints(
ca = true,
pathLenConstraint = 3
maxIntermediateCas = maxIntermediateCas.toLong()
)
)
}
Expand All @@ -428,9 +433,9 @@ class HeldCertificate(
}
}
result += Extension(
extnID = ObjectIdentifiers.subjectAlternativeName,
id = subjectAlternativeName,
critical = true,
extnValue = extensionValue
value = extensionValue
)
}

Expand All @@ -440,20 +445,21 @@ class HeldCertificate(
private fun signatureAlgorithm(signedByKeyPair: KeyPair): AlgorithmIdentifier {
return when (signedByKeyPair.private) {
is RSAPrivateKey -> AlgorithmIdentifier(
algorithm = ObjectIdentifiers.sha256WithRSAEncryption,
algorithm = sha256WithRSAEncryption,
parameters = null
)
else -> AlgorithmIdentifier(
algorithm = ObjectIdentifiers.sha256withEcdsa,
algorithm = sha256withEcdsa,
parameters = ByteString.EMPTY
)
}
}

private fun generateKeyPair(): KeyPair {
val keyPairGenerator = KeyPairGenerator.getInstance(keyAlgorithm)
keyPairGenerator.initialize(keySize, SecureRandom())
return keyPairGenerator.generateKeyPair()
return KeyPairGenerator.getInstance(keyAlgorithm).run {
initialize(keySize, SecureRandom())
generateKeyPair()
}
}

companion object {
Expand Down
64 changes: 27 additions & 37 deletions okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/Adapters.kt
Original file line number Diff line number Diff line change
Expand Up @@ -239,35 +239,29 @@ internal object Adapters {
): BasicDerAdapter<T> {
val codec = object : BasicDerAdapter.Codec<T> {
override fun decode(reader: DerReader): T {
reader.pushTypeHint()
try {
return reader.withTypeHint {
val list = mutableListOf<Any?>()

while (list.size < members.size) {
val member = members[list.size]
list += member.readValue(reader)
list += member.fromDer(reader)
}

if (reader.hasNext()) {
throw IOException("unexpected ${reader.peekHeader()}")
throw IOException("unexpected ${reader.peekHeader()} at $reader")
}

return construct(list)
} finally {
reader.popTypeHint()
return@withTypeHint construct(list)
}
}

override fun encode(writer: DerWriter, value: T) {
val list = decompose(value)
writer.pushTypeHint()
try {
writer.withTypeHint {
for (i in list.indices) {
val adapter = members[i] as DerAdapter<Any?>
adapter.writeValue(writer, list[i])
adapter.toDer(writer, list[i])
}
} finally {
writer.popTypeHint()
}
}
}
Expand All @@ -285,19 +279,19 @@ internal object Adapters {
return object : DerAdapter<Pair<DerAdapter<*>, Any?>> {
override fun matches(header: DerHeader): Boolean = true

override fun readValue(reader: DerReader): Pair<DerAdapter<*>, Any?> {
override fun fromDer(reader: DerReader): Pair<DerAdapter<*>, Any?> {
val peekedHeader = reader.peekHeader()
?: throw IOException("expected a value")
?: throw IOException("expected a value at $reader")

val choice = choices.firstOrNull { it.matches(peekedHeader) }
?: throw IOException("expected a matching choice but was $peekedHeader")
?: throw IOException("expected a matching choice but was $peekedHeader at $reader")

return choice to choice.readValue(reader)
return choice to choice.fromDer(reader)
}

override fun writeValue(writer: DerWriter, value: Pair<DerAdapter<*>, Any?>) {
override fun toDer(writer: DerWriter, value: Pair<DerAdapter<*>, Any?>) {
val (adapter, v) = value
(adapter as DerAdapter<Any?>).writeValue(writer, v)
(adapter as DerAdapter<Any?>).toDer(writer, v)
}

override fun toString() = choices.joinToString(separator = " OR ")
Expand All @@ -317,27 +311,23 @@ internal object Adapters {
chooser: (Any?) -> DerAdapter<*>?
): DerAdapter<Any?> {
return object : DerAdapter<Any?> {
override fun matches(header: DerHeader): Boolean = true
override fun matches(header: DerHeader) = true

override fun writeValue(writer: DerWriter, value: Any?) {
override fun toDer(writer: DerWriter, value: Any?) {
// If we don't understand this hint, encode the body as a byte string. The byte string
// will include a tag and length header as a prefix.
val adapter = chooser(writer.typeHint)

if (adapter != null) {
(adapter as DerAdapter<Any?>).writeValue(writer, value)
} else {
writer.writeOctetString(value as ByteString)
val adapter = chooser(writer.typeHint) as DerAdapter<Any?>?
when {
adapter != null -> adapter.toDer(writer, value)
else -> writer.writeOctetString(value as ByteString)
}
}

override fun readValue(reader: DerReader): Any? {
val adapter = chooser(reader.typeHint)

if (adapter != null) {
return (adapter as DerAdapter<Any?>).readValue(reader)
} else {
return reader.readOctetString()
override fun fromDer(reader: DerReader): Any? {
val adapter = chooser(reader.typeHint) as DerAdapter<Any?>?
return when {
adapter != null -> adapter.fromDer(reader)
else -> reader.readOctetString()
}
}
}
Expand Down Expand Up @@ -370,7 +360,7 @@ internal object Adapters {
return object : DerAdapter<Any?> {
override fun matches(header: DerHeader): Boolean = true

override fun writeValue(writer: DerWriter, value: Any?) {
override fun toDer(writer: DerWriter, value: Any?) {
when {
isOptional && value == optionalValue -> {
// Write nothing.
Expand All @@ -385,20 +375,20 @@ internal object Adapters {
else -> {
for ((type, adapter) in choices) {
if (type.isInstance(value) || (value == null && type == Unit::class)) {
(adapter as DerAdapter<Any?>).writeValue(writer, value)
(adapter as DerAdapter<Any?>).toDer(writer, value)
return
}
}
}
}
}

override fun readValue(reader: DerReader): Any? {
override fun fromDer(reader: DerReader): Any? {
if (isOptional && !reader.hasNext()) return optionalValue

for ((_, adapter) in choices) {
if (adapter.matches(reader.peekHeader()!!)) {
return adapter.readValue(reader)
return adapter.fromDer(reader)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ internal data class BasicDerAdapter<T>(

override fun matches(header: DerHeader) = header.tagClass == tagClass && header.tag == tag

override fun readValue(reader: DerReader): T {
override fun fromDer(reader: DerReader): T {
val peekedHeader = reader.peekHeader()
if (peekedHeader == null || peekedHeader.tagClass != tagClass || peekedHeader.tag != tag) {
if (isOptional) return defaultValue as T
Expand All @@ -70,7 +70,7 @@ internal data class BasicDerAdapter<T>(
return result
}

override fun writeValue(writer: DerWriter, value: T) {
override fun toDer(writer: DerWriter, value: T) {
if (typeHint) {
writer.typeHint = value
}
Expand Down
Loading

0 comments on commit 353a52b

Please sign in to comment.