Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve code style for the new DER package #6157

Merged
merged 1 commit into from
Jun 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 24 additions & 18 deletions okhttp-tls/src/main/kotlin/okhttp3/tls/HeldCertificate.kt
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ import okhttp3.tls.internal.der.CertificateAdapters.generalNameDnsName
import okhttp3.tls.internal.der.CertificateAdapters.generalNameIpAddress
import okhttp3.tls.internal.der.Extension
import okhttp3.tls.internal.der.ObjectIdentifiers
import okhttp3.tls.internal.der.ObjectIdentifiers.basicConstraints
import okhttp3.tls.internal.der.ObjectIdentifiers.organizationalUnitName
import okhttp3.tls.internal.der.ObjectIdentifiers.sha256WithRSAEncryption
import okhttp3.tls.internal.der.ObjectIdentifiers.sha256withEcdsa
import okhttp3.tls.internal.der.ObjectIdentifiers.subjectAlternativeName
import okhttp3.tls.internal.der.TbsCertificate
import okhttp3.tls.internal.der.Validity
import okio.ByteString
Expand Down Expand Up @@ -187,8 +192,8 @@ class HeldCertificate(
class Builder {
private var notBefore = -1L
private var notAfter = -1L
private var cn: String? = null
private var ou: String? = null
private var commonName: String? = null
private var organizationalUnit: String? = null
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

private val altNames = mutableListOf<String>()
private var serialNumber: BigInteger? = null
private var keyPair: KeyPair? = null
Expand Down Expand Up @@ -240,12 +245,12 @@ class HeldCertificate(
* [rfc_2818]: https://tools.ietf.org/html/rfc2818
*/
fun commonName(cn: String) = apply {
this.cn = cn
this.commonName = cn
}

/** Sets the certificate's organizational unit (OU). If unset this field will be omitted. */
fun organizationalUnit(ou: String) = apply {
this.ou = ou
this.organizationalUnit = ou
}

/** Sets this certificate's serial number. If unset the serial number will be 1. */
Expand Down Expand Up @@ -378,16 +383,16 @@ class HeldCertificate(
private fun subject(): List<List<AttributeTypeAndValue>> {
val result = mutableListOf<List<AttributeTypeAndValue>>()

if (ou != null) {
if (organizationalUnit != null) {
result += listOf(AttributeTypeAndValue(
type = ObjectIdentifiers.organizationalUnitName,
value = ou
type = organizationalUnitName,
value = organizationalUnit
))
}

result += listOf(AttributeTypeAndValue(
type = ObjectIdentifiers.commonName,
value = cn ?: UUID.randomUUID().toString()
value = commonName ?: UUID.randomUUID().toString()
))

return result
Expand All @@ -407,11 +412,11 @@ class HeldCertificate(

if (maxIntermediateCas != -1) {
result += Extension(
extnID = ObjectIdentifiers.basicConstraints,
id = basicConstraints,
critical = true,
extnValue = BasicConstraints(
value = BasicConstraints(
ca = true,
pathLenConstraint = 3
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

only behavior change. This was a copy-paste bug and the test happens to be asserting that the sample value is exactly 3!

maxIntermediateCas = maxIntermediateCas.toLong()
)
)
}
Expand All @@ -428,9 +433,9 @@ class HeldCertificate(
}
}
result += Extension(
extnID = ObjectIdentifiers.subjectAlternativeName,
id = subjectAlternativeName,
critical = true,
extnValue = extensionValue
value = extensionValue
)
}

Expand All @@ -440,20 +445,21 @@ class HeldCertificate(
private fun signatureAlgorithm(signedByKeyPair: KeyPair): AlgorithmIdentifier {
return when (signedByKeyPair.private) {
is RSAPrivateKey -> AlgorithmIdentifier(
algorithm = ObjectIdentifiers.sha256WithRSAEncryption,
algorithm = sha256WithRSAEncryption,
parameters = null
)
else -> AlgorithmIdentifier(
algorithm = ObjectIdentifiers.sha256withEcdsa,
algorithm = sha256withEcdsa,
parameters = ByteString.EMPTY
)
}
}

private fun generateKeyPair(): KeyPair {
val keyPairGenerator = KeyPairGenerator.getInstance(keyAlgorithm)
keyPairGenerator.initialize(keySize, SecureRandom())
return keyPairGenerator.generateKeyPair()
return KeyPairGenerator.getInstance(keyAlgorithm).run {
initialize(keySize, SecureRandom())
generateKeyPair()
}
}

companion object {
Expand Down
64 changes: 27 additions & 37 deletions okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/Adapters.kt
Original file line number Diff line number Diff line change
Expand Up @@ -239,35 +239,29 @@ internal object Adapters {
): BasicDerAdapter<T> {
val codec = object : BasicDerAdapter.Codec<T> {
override fun decode(reader: DerReader): T {
reader.pushTypeHint()
try {
return reader.withTypeHint {
val list = mutableListOf<Any?>()

while (list.size < members.size) {
val member = members[list.size]
list += member.readValue(reader)
list += member.fromDer(reader)
}

if (reader.hasNext()) {
throw IOException("unexpected ${reader.peekHeader()}")
throw IOException("unexpected ${reader.peekHeader()} at $reader")
}

return construct(list)
} finally {
reader.popTypeHint()
return@withTypeHint construct(list)
}
}

override fun encode(writer: DerWriter, value: T) {
val list = decompose(value)
writer.pushTypeHint()
try {
writer.withTypeHint {
for (i in list.indices) {
val adapter = members[i] as DerAdapter<Any?>
adapter.writeValue(writer, list[i])
adapter.toDer(writer, list[i])
}
} finally {
writer.popTypeHint()
}
}
}
Expand All @@ -285,19 +279,19 @@ internal object Adapters {
return object : DerAdapter<Pair<DerAdapter<*>, Any?>> {
override fun matches(header: DerHeader): Boolean = true

override fun readValue(reader: DerReader): Pair<DerAdapter<*>, Any?> {
override fun fromDer(reader: DerReader): Pair<DerAdapter<*>, Any?> {
val peekedHeader = reader.peekHeader()
?: throw IOException("expected a value")
?: throw IOException("expected a value at $reader")

val choice = choices.firstOrNull { it.matches(peekedHeader) }
?: throw IOException("expected a matching choice but was $peekedHeader")
?: throw IOException("expected a matching choice but was $peekedHeader at $reader")

return choice to choice.readValue(reader)
return choice to choice.fromDer(reader)
}

override fun writeValue(writer: DerWriter, value: Pair<DerAdapter<*>, Any?>) {
override fun toDer(writer: DerWriter, value: Pair<DerAdapter<*>, Any?>) {
val (adapter, v) = value
(adapter as DerAdapter<Any?>).writeValue(writer, v)
(adapter as DerAdapter<Any?>).toDer(writer, v)
}

override fun toString() = choices.joinToString(separator = " OR ")
Expand All @@ -317,27 +311,23 @@ internal object Adapters {
chooser: (Any?) -> DerAdapter<*>?
): DerAdapter<Any?> {
return object : DerAdapter<Any?> {
override fun matches(header: DerHeader): Boolean = true
override fun matches(header: DerHeader) = true

override fun writeValue(writer: DerWriter, value: Any?) {
override fun toDer(writer: DerWriter, value: Any?) {
// If we don't understand this hint, encode the body as a byte string. The byte string
// will include a tag and length header as a prefix.
val adapter = chooser(writer.typeHint)

if (adapter != null) {
(adapter as DerAdapter<Any?>).writeValue(writer, value)
} else {
writer.writeOctetString(value as ByteString)
val adapter = chooser(writer.typeHint) as DerAdapter<Any?>?
when {
adapter != null -> adapter.toDer(writer, value)
else -> writer.writeOctetString(value as ByteString)
}
}

override fun readValue(reader: DerReader): Any? {
val adapter = chooser(reader.typeHint)

if (adapter != null) {
return (adapter as DerAdapter<Any?>).readValue(reader)
} else {
return reader.readOctetString()
override fun fromDer(reader: DerReader): Any? {
val adapter = chooser(reader.typeHint) as DerAdapter<Any?>?
return when {
adapter != null -> adapter.fromDer(reader)
else -> reader.readOctetString()
}
}
}
Expand Down Expand Up @@ -370,7 +360,7 @@ internal object Adapters {
return object : DerAdapter<Any?> {
override fun matches(header: DerHeader): Boolean = true

override fun writeValue(writer: DerWriter, value: Any?) {
override fun toDer(writer: DerWriter, value: Any?) {
when {
isOptional && value == optionalValue -> {
// Write nothing.
Expand All @@ -385,20 +375,20 @@ internal object Adapters {
else -> {
for ((type, adapter) in choices) {
if (type.isInstance(value) || (value == null && type == Unit::class)) {
(adapter as DerAdapter<Any?>).writeValue(writer, value)
(adapter as DerAdapter<Any?>).toDer(writer, value)
return
}
}
}
}
}

override fun readValue(reader: DerReader): Any? {
override fun fromDer(reader: DerReader): Any? {
if (isOptional && !reader.hasNext()) return optionalValue

for ((_, adapter) in choices) {
if (adapter.matches(reader.peekHeader()!!)) {
return adapter.readValue(reader)
return adapter.fromDer(reader)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ internal data class BasicDerAdapter<T>(

override fun matches(header: DerHeader) = header.tagClass == tagClass && header.tag == tag

override fun readValue(reader: DerReader): T {
override fun fromDer(reader: DerReader): T {
val peekedHeader = reader.peekHeader()
if (peekedHeader == null || peekedHeader.tagClass != tagClass || peekedHeader.tag != tag) {
if (isOptional) return defaultValue as T
Expand All @@ -70,7 +70,7 @@ internal data class BasicDerAdapter<T>(
return result
}

override fun writeValue(writer: DerWriter, value: T) {
override fun toDer(writer: DerWriter, value: T) {
if (typeHint) {
writer.typeHint = value
}
Expand Down
Loading