Skip to content

Commit

Permalink
feat(PrivateNodeRegistration): Add support for session keys (#202)
Browse files Browse the repository at this point in the history
  • Loading branch information
gnarea authored Nov 17, 2021
1 parent c068c79 commit 6f9c2dd
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@ package tech.relaycorp.relaynet.messages.control

import org.bouncycastle.asn1.ASN1TaggedObject
import org.bouncycastle.asn1.DEROctetString
import org.bouncycastle.asn1.DERSequence
import tech.relaycorp.relaynet.SessionKey
import tech.relaycorp.relaynet.messages.InvalidMessageException
import tech.relaycorp.relaynet.wrappers.KeyException
import tech.relaycorp.relaynet.wrappers.asn1.ASN1Exception
import tech.relaycorp.relaynet.wrappers.asn1.ASN1Utils
import tech.relaycorp.relaynet.wrappers.deserializeECPublicKey
import tech.relaycorp.relaynet.wrappers.x509.Certificate
import tech.relaycorp.relaynet.wrappers.x509.CertificateException

Expand All @@ -16,21 +20,35 @@ import tech.relaycorp.relaynet.wrappers.x509.CertificateException
*
* @param privateNodeCertificate The certificate of the private node
* @param gatewayCertificate The certificate of the gateway acting as server
* @param gatewaySessionKey The session key of the gateway acting as server
*/
class PrivateNodeRegistration(
val privateNodeCertificate: Certificate,
val gatewayCertificate: Certificate
val gatewayCertificate: Certificate,
val gatewaySessionKey: SessionKey? = null,
) {
/**
* Serialize registration.
*/
fun serialize(): ByteArray {
val nodeCertificateASN1 = DEROctetString(privateNodeCertificate.serialize())
val gatewayCertificateASN1 = DEROctetString(gatewayCertificate.serialize())
return ASN1Utils.serializeSequence(
listOf(nodeCertificateASN1, gatewayCertificateASN1),
false
)
val gatewaySessionKeyASN1 = if (gatewaySessionKey != null) {
ASN1Utils.makeSequence(
listOf(
DEROctetString(gatewaySessionKey.keyId),
DEROctetString(gatewaySessionKey.publicKey.encoded),
),
false
)
} else {
null
}
val rootSequence = listOf(
nodeCertificateASN1,
gatewayCertificateASN1
) + listOfNotNull(gatewaySessionKeyASN1)
return ASN1Utils.serializeSequence(rootSequence, false)
}

companion object {
Expand Down Expand Up @@ -66,7 +84,34 @@ class PrivateNodeRegistration(
exc
)
}
return PrivateNodeRegistration(nodeCertificate, gatewayCertificate)
val gatewaySessionKey =
if (3 <= sequence.size) getSessionKeyFromSequence(sequence[2]) else null
return PrivateNodeRegistration(nodeCertificate, gatewayCertificate, gatewaySessionKey)
}

private fun getSessionKeyFromSequence(sessionKeyASN1: ASN1TaggedObject): SessionKey {
val sessionKeySequence = DERSequence.getInstance(sessionKeyASN1, false)
if (sessionKeySequence.size() < 2) {
throw InvalidMessageException(
"Session key SEQUENCE should have at least 2 items " +
"(got ${sessionKeySequence.size()})"
)
}
val sessionKeyId = ASN1Utils.getOctetString(
sessionKeySequence.getObjectAt(0) as ASN1TaggedObject,
).octets

val sessionPublicKeyASN1 =
ASN1Utils.getOctetString(sessionKeySequence.getObjectAt(1) as ASN1TaggedObject)
val sessionPublicKey = try {
sessionPublicKeyASN1.octets.deserializeECPublicKey()
} catch (exc: KeyException) {
throw InvalidMessageException(
"Session key is not a valid ECDH public key",
exc
)
}
return SessionKey(sessionKeyId, sessionPublicKey)
}

@Throws(CertificateException::class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,27 @@ package tech.relaycorp.relaynet.messages.control

import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertNull
import kotlin.test.assertTrue
import org.bouncycastle.asn1.ASN1Sequence
import org.bouncycastle.asn1.ASN1TaggedObject
import org.bouncycastle.asn1.DERNull
import org.bouncycastle.asn1.DEROctetString
import org.junit.jupiter.api.Nested
import org.junit.jupiter.api.assertThrows
import tech.relaycorp.relaynet.SessionKey
import tech.relaycorp.relaynet.SessionKeyPair
import tech.relaycorp.relaynet.messages.InvalidMessageException
import tech.relaycorp.relaynet.utils.KeyPairSet
import tech.relaycorp.relaynet.utils.PDACertPath
import tech.relaycorp.relaynet.wrappers.KeyException
import tech.relaycorp.relaynet.wrappers.asn1.ASN1Exception
import tech.relaycorp.relaynet.wrappers.asn1.ASN1Utils
import tech.relaycorp.relaynet.wrappers.x509.CertificateException

class PrivateNodeRegistrationTest {
private val gatewaySessionKey = (SessionKeyPair.generate()).sessionKey

@Nested
inner class Serialize {
@Test
Expand Down Expand Up @@ -45,6 +54,57 @@ class PrivateNodeRegistrationTest {
gatewayCertificateASN1.octets.asList()
)
}

@Nested
inner class GatewaySessionKey {
@Test
fun `Session key should be absent from serialization if it does not exist`() {
val registration =
PrivateNodeRegistration(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW)

val serialization = registration.serialize()

val sequence = ASN1Utils.deserializeHeterogeneousSequence(serialization)
assertEquals(2, sequence.size)
}

@Test
fun `Key id should be serialized`() {
val registration = PrivateNodeRegistration(
PDACertPath.PRIVATE_ENDPOINT,
PDACertPath.PRIVATE_GW,
gatewaySessionKey
)

val serialization = registration.serialize()

val sequence = ASN1Utils.deserializeHeterogeneousSequence(serialization)
val sessionKeyASN1 = ASN1Sequence.getInstance(sequence[2], false)
val keyIdASN1 =
ASN1Utils.getOctetString(sessionKeyASN1.getObjectAt(0) as ASN1TaggedObject)
assertEquals(gatewaySessionKey.keyId.asList(), keyIdASN1.octets.asList())
}

@Test
fun `Public key should be serialized`() {
val registration = PrivateNodeRegistration(
PDACertPath.PRIVATE_ENDPOINT,
PDACertPath.PRIVATE_GW,
gatewaySessionKey
)

val serialization = registration.serialize()

val sequence = ASN1Utils.deserializeHeterogeneousSequence(serialization)
val sessionKeyASN1 = ASN1Sequence.getInstance(sequence[2], false)
val sessionPublicKeyASN1 =
ASN1Utils.getOctetString(sessionKeyASN1.getObjectAt(1) as ASN1TaggedObject)
assertEquals(
gatewaySessionKey.publicKey.encoded.asList(),
sessionPublicKeyASN1.octets.asList()
)
}
}
}

@Nested
Expand Down Expand Up @@ -114,7 +174,7 @@ class PrivateNodeRegistrationTest {
}

@Test
fun `Valid registration should be accepted`() {
fun `Valid registration without session key should be accepted`() {
val registration =
PrivateNodeRegistration(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW)
val serialization = registration.serialize()
Expand All @@ -129,6 +189,71 @@ class PrivateNodeRegistrationTest {
PDACertPath.PRIVATE_GW,
registrationDeserialized.gatewayCertificate
)
assertNull(registrationDeserialized.gatewaySessionKey)
}

@Nested
inner class GatewaySessionKey {
@Test
fun `SEQUENCE should contain at least two items`() {
val invalidSerialization = ASN1Utils.serializeSequence(
listOf(
DEROctetString(PDACertPath.PRIVATE_ENDPOINT.serialize()),
DEROctetString(PDACertPath.PRIVATE_GW.serialize()),
ASN1Utils.makeSequence(listOf(DEROctetString(gatewaySessionKey.keyId)))
),
false
)

val exception = assertThrows<InvalidMessageException> {
PrivateNodeRegistration.deserialize(invalidSerialization)
}

assertEquals(
"Session key SEQUENCE should have at least 2 items (got 1)",
exception.message
)
}

@Test
fun `Session key should be a valid ECDH public key`() {
val invalidRegistration = PrivateNodeRegistration(
PDACertPath.PRIVATE_ENDPOINT,
PDACertPath.PRIVATE_GW,
SessionKey(
gatewaySessionKey.keyId,
KeyPairSet.PRIVATE_ENDPOINT.public, // Invalid: Not an ECDH key.
)
)
val invalidSerialization = invalidRegistration.serialize()

val exception = assertThrows<InvalidMessageException> {
PrivateNodeRegistration.deserialize(invalidSerialization)
}

assertEquals(
"Session key is not a valid ECDH public key",
exception.message
)
assertTrue(exception.cause is KeyException)
}

@Test
fun `Valid registration with session key should be accepted`() {
val registration = PrivateNodeRegistration(
PDACertPath.PRIVATE_ENDPOINT,
PDACertPath.PRIVATE_GW,
gatewaySessionKey
)
val serialization = registration.serialize()

val registrationDeserialized = PrivateNodeRegistration.deserialize(serialization)

assertEquals(
gatewaySessionKey,
registrationDeserialized.gatewaySessionKey
)
}
}
}
}

0 comments on commit 6f9c2dd

Please sign in to comment.